Add visualize codes for Q
This commit is contained in:
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch.nn as nn
|
||||
|
||||
# from utils import obtain_accuracy
|
||||
# modules in AutoDL
|
||||
from models import CellStructure
|
||||
from log_utils import time_string
|
||||
|
||||
@@ -56,11 +56,20 @@ def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
|
||||
correct = (preds == targets.cuda()).float()
|
||||
accuracies.append(correct.mean().item())
|
||||
if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
|
||||
cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[: idx + 1])[0, 1]
|
||||
cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[: idx + 1])[0, 1]
|
||||
cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[: idx + 1])[
|
||||
0, 1
|
||||
]
|
||||
cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[: idx + 1])[
|
||||
0, 1
|
||||
]
|
||||
print(
|
||||
"{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.".format(
|
||||
time_string(), idx, len(archs), "Train" if cal_mode else "Eval", cor_accs_valid, cor_accs_test
|
||||
time_string(),
|
||||
idx,
|
||||
len(archs),
|
||||
"Train" if cal_mode else "Eval",
|
||||
cor_accs_valid,
|
||||
cor_accs_test,
|
||||
)
|
||||
)
|
||||
model.load_state_dict(weights)
|
||||
|
Reference in New Issue
Block a user