Update SuperMLP
This commit is contained in:
@@ -36,7 +36,12 @@ def search_train(
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = (
|
||||
extra_info["epoch-str"],
|
||||
@@ -46,10 +51,16 @@ def search_train(
|
||||
)
|
||||
|
||||
network.train()
|
||||
logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
|
||||
logger.log(
|
||||
"[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
|
||||
epoch_str, flop_need, flop_weight
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
network.apply(change_key("search_mode", "search"))
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
|
||||
search_loader
|
||||
):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
@@ -75,7 +86,9 @@ def search_train(
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop("genotype", None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(
|
||||
expected_flop, flop_cur, flop_need, flop_tolerant
|
||||
)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
@@ -90,7 +103,11 @@ def search_train(
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step + 1) == len(search_loader):
|
||||
Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
|
||||
Sstr = (
|
||||
"**TRAIN** "
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
|
||||
)
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
@@ -153,7 +170,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
|
||||
end = time.time()
|
||||
|
||||
if i % print_freq == 0 or (i + 1) == len(xloader):
|
||||
Sstr = "**VALID** " + time_string() + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
|
||||
Sstr = (
|
||||
"**VALID** "
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
|
||||
)
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
@@ -165,7 +186,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
|
||||
|
||||
logger.log(
|
||||
" **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
|
||||
top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
loss=losses.avg,
|
||||
)
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user