Update weight watcher codes
This commit is contained in:
@@ -411,7 +411,11 @@ class ArchResults(object):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
return {seed: self.all_results[(dataset, seed)].get_net_param() for seed in x_seeds}
|
||||
else:
|
||||
return self.all_results[(dataset, seed)].get_net_param()
|
||||
xkey = (dataset, seed)
|
||||
if xkey in self.all_results:
|
||||
return self.all_results[xkey].get_net_param()
|
||||
else:
|
||||
raise ValueError('key={:} not in {:}'.format(xkey, list(self.all_results.keys())))
|
||||
|
||||
def reset_latency(self, dataset: Text, seed: Union[None, Text], latency: float) -> None:
|
||||
"""This function is used to reset the latency in all corresponding ResultsCount(s)."""
|
||||
|
Reference in New Issue
Block a user