Update the test codes for NAS-Bench-API
This commit is contained in:
@@ -660,15 +660,21 @@ class ResultsCount(object):
|
||||
"""Get the evaluation information ; there could be multiple evaluation sets (identified by the 'name' argument)."""
|
||||
if iepoch is None: iepoch = self.epochs-1
|
||||
assert 0 <= iepoch < self.epochs, 'invalid iepoch={:} < {:}'.format(iepoch, self.epochs)
|
||||
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
|
||||
xtime = self.eval_times['{:}@{:}'.format(name,iepoch)]
|
||||
atime = sum([self.eval_times['{:}@{:}'.format(name,i)] for i in range(iepoch+1)])
|
||||
else: xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.eval_losses['{:}@{:}'.format(name,iepoch)],
|
||||
'accuracy': self.eval_acc1es['{:}@{:}'.format(name,iepoch)],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
def _internal_query(xname):
|
||||
if isinstance(self.eval_times,dict) and len(self.eval_times) > 0:
|
||||
xtime = self.eval_times['{:}@{:}'.format(xname, iepoch)]
|
||||
atime = sum([self.eval_times['{:}@{:}'.format(xname, i)] for i in range(iepoch+1)])
|
||||
else:
|
||||
xtime, atime = None, None
|
||||
return {'iepoch' : iepoch,
|
||||
'loss' : self.eval_losses['{:}@{:}'.format(xname, iepoch)],
|
||||
'accuracy': self.eval_acc1es['{:}@{:}'.format(xname, iepoch)],
|
||||
'cur_time': xtime,
|
||||
'all_time': atime}
|
||||
if name == 'valid':
|
||||
return _internal_query('x-valid')
|
||||
else:
|
||||
return _internal_query(name)
|
||||
|
||||
def get_net_param(self, clone=False):
|
||||
if clone: return copy.deepcopy(self.net_state_dict)
|
||||
|
Reference in New Issue
Block a user