update GDAS

This commit is contained in:
D-X-Y
2019-11-19 11:58:04 +11:00
parent c3672648d7
commit 09d68c6375
20 changed files with 1176 additions and 90 deletions

View File

@@ -1,7 +1,7 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import os, sys, copy, torch, numpy as np
import os, sys, copy, random, torch, numpy as np
from collections import OrderedDict
@@ -149,7 +149,7 @@ class ArchResults(object):
lantencies = [result.get_latency() for result in results]
return np.mean(flops), np.mean(params), np.mean(lantencies)
def get_metrics(self, dataset, setname, iepoch=None):
def get_metrics(self, dataset, setname, iepoch=None, is_random=False):
x_seeds = self.dataset_seed[dataset]
results = [self.all_results[ (dataset, seed) ] for seed in x_seeds]
loss, accuracy = [], []
@@ -160,7 +160,11 @@ class ArchResults(object):
info = result.get_eval(setname, iepoch)
loss.append( info['loss'] )
accuracy.append( info['accuracy'] )
return float(np.mean(loss)), float(np.mean(accuracy))
if is_random:
index = random.randint(0, len(loss)-1)
return loss[index], accuracy[index]
else:
return float(np.mean(loss)), float(np.mean(accuracy))
def show(self, is_print=False):
return print_information(self, None, is_print)