Add int search space
This commit is contained in:
@@ -76,7 +76,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode):
|
||||
return losses.avg, top1.avg, top5.avg, batch_time.sum
|
||||
|
||||
|
||||
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):
|
||||
def evaluate_for_seed(
|
||||
arch_config, config, arch, train_loader, valid_loaders, seed, logger
|
||||
):
|
||||
|
||||
prepare_seed(seed) # random seed
|
||||
net = get_cell_based_tiny_net(
|
||||
@@ -94,14 +96,29 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, se
|
||||
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
|
||||
flop, param = get_model_infos(net, config.xshape)
|
||||
logger.log("Network : {:}".format(net.get_message()), False)
|
||||
logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed))
|
||||
logger.log(
|
||||
"{:} Seed-------------------------- {:} --------------------------".format(
|
||||
time_string(), seed
|
||||
)
|
||||
)
|
||||
logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
|
||||
# train and valid
|
||||
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
|
||||
network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
|
||||
# start training
|
||||
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
|
||||
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
|
||||
start_time, epoch_time, total_epoch = (
|
||||
time.time(),
|
||||
AverageMeter(),
|
||||
config.epochs + config.warmup,
|
||||
)
|
||||
(
|
||||
train_losses,
|
||||
train_acc1es,
|
||||
train_acc5es,
|
||||
valid_losses,
|
||||
valid_acc1es,
|
||||
valid_acc5es,
|
||||
) = ({}, {}, {}, {}, {}, {})
|
||||
train_times, valid_times = {}, {}
|
||||
for epoch in range(total_epoch):
|
||||
scheduler.update(epoch, 0.0)
|
||||
@@ -126,7 +143,9 @@ def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, se
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True))
|
||||
need_time = "Time Left: {:}".format(
|
||||
convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
|
||||
)
|
||||
logger.log(
|
||||
"{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]".format(
|
||||
time_string(),
|
||||
|
Reference in New Issue
Block a user