update GDAS
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, copy, torch, numpy as np
|
||||
import os, sys, copy, random, torch, numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
@@ -149,7 +149,7 @@ class ArchResults(object):
|
||||
lantencies = [result.get_latency() for result in results]
|
||||
return np.mean(flops), np.mean(params), np.mean(lantencies)
|
||||
|
||||
def get_metrics(self, dataset, setname, iepoch=None):
|
||||
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
|
||||
x_seeds = self.dataset_seed[dataset]
|
||||
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
|
||||
loss, accuracy = [], []
|
||||
@@ -160,7 +160,11 @@ class ArchResults(object):
|
||||
info = result.get_eval(setname, iepoch)
|
||||
loss.append( info['loss'] )
|
||||
accuracy.append( info['accuracy'] )
|
||||
return float(np.mean(loss)), float(np.mean(accuracy))
|
||||
if is_random:
|
||||
index = random.randint(0, len(loss)-1)
|
||||
return loss[index], accuracy[index]
|
||||
else:
|
||||
return float(np.mean(loss)), float(np.mean(accuracy))
|
||||
|
||||
def show(self, is_print=False):
|
||||
return print_information(self, None, is_print)
|
||||
|
Reference in New Issue
Block a user