Update the test codes for NAS-Bench-API

This commit is contained in:
D-X-Y
2020-07-01 12:29:46 +00:00
parent 1906454a73
commit a45808b8e6
5 changed files with 287 additions and 212 deletions

View File

@@ -660,15 +660,21 @@ class ResultsCount(object):
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
if iepoch is None: iepoch = self.epochs-1
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
xtime = self.eval_times['{:}@{:}'.format(name,iepoch)]
atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)])
else: xtime, atime = None, None
return {'iepoch' : iepoch,
'loss' : self.eval_losses['{:}@{:}'.format(name,iepoch)],
'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)],
'cur_time': xtime,
'all_time': atime}
def _internal_query(xname):
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)]
atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)])
else:
xtime, atime = None, None
return {'iepoch' : iepoch,
'loss' : self.eval_losses['{:}@{:}'.format(xname, iepoch)],
'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
'cur_time': xtime,
'all_time': atime}
if name == 'valid':
return _internal_query('x-valid')
else:
return _internal_query(name)
def get_net_param(self, clone=False):
if clone: return copy.deepcopy(self.net_state_dict)