Update SuperMLP

D-X-Y
2021-03-19 23:57:23 +08:00
parent 31b8122cc1
commit 0c56a729ad
13 changed files with 412 additions and 85 deletions

@@ -36,7 +36,12 @@ def search_train(
     logger,
 ):
     data_time, batch_time = AverageMeter(), AverageMeter()
-    base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
+    base_losses, arch_losses, top1, top5 = (
+        AverageMeter(),
+        AverageMeter(),
+        AverageMeter(),
+        AverageMeter(),
+    )
     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
     epoch_str, flop_need, flop_weight, flop_tolerant = (
         extra_info["epoch-str"],
@@ -46,10 +51,16 @@ def search_train(
     )
     network.train()
-    logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
+    logger.log(
+        "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
+            epoch_str, flop_need, flop_weight
+        )
+    )
     end = time.time()
     network.apply(change_key("search_mode", "search"))
-    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
+    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
+        search_loader
+    ):
         scheduler.update(None, 1.0 * step / len(search_loader))
         # calculate prediction and loss
         base_targets = base_targets.cuda(non_blocking=True)
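
The scheduler.update(None, 1.0 * step / len(search_loader)) call passes a within-epoch fraction so the learning rate can be annealed per step rather than once per epoch, with None meaning "stay in the current epoch". A hypothetical sketch of a scheduler honoring that contract; the class name, base_lr, and cosine schedule are illustrative assumptions, not the repo's actual implementation:

    import math

    class CosineScheduler:
        """Cosine-annealed LR that accepts fractional epoch progress."""

        def __init__(self, optimizer, base_lr, max_epochs):
            self.optimizer, self.base_lr = optimizer, base_lr
            self.max_epochs, self.cur_epoch = max_epochs, 0

        def update(self, epoch, fraction):
            # epoch=None keeps the current epoch; fraction is in [0, 1)
            if epoch is not None:
                self.cur_epoch = epoch
            t = (self.cur_epoch + fraction) / self.max_epochs
            lr = 0.5 * self.base_lr * (1 + math.cos(math.pi * t))
            for group in self.optimizer.param_groups:
                group["lr"] = lr
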
@@ -75,7 +86,9 @@ def search_train(
         arch_optimizer.zero_grad()
         logits, expected_flop = network(arch_inputs)
         flop_cur = network.module.get_flop("genotype", None, None)
-        flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
+        flop_loss, flop_loss_scale = get_flop_loss(
+            expected_flop, flop_cur, flop_need, flop_tolerant
+        )
         acls_loss = criterion(logits, arch_targets)
         arch_loss = acls_loss + flop_loss * flop_weight
         arch_loss.backward()
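
Here get_flop_loss turns the network's expected FLOP count into a differentiable penalty that is only active outside a tolerance band around the budget flop_need; the returned scale is a plain float for logging. The real implementation lives elsewhere in the repository; this is a plausible sketch with the signature taken from the call above and the internals assumed:

    def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
        # expected_flop is a tensor from the network, so the loss stays
        # differentiable; flop_cur is the FLOPs of the current discrete genotype.
        if flop_cur < flop_need - flop_tolerant:
            scale = -1.0  # well under budget: reward larger architectures
        elif flop_cur > flop_need:
            scale = 1.0   # over budget: penalize expected FLOPs
        else:
            scale = 0.0   # inside the tolerance band: no pressure
        return expected_flop * scale / flop_need, scale

With arch_loss = acls_loss + flop_loss * flop_weight, the architecture parameters are then pushed toward the FLOP budget only when the sampled genotype drifts outside the tolerated range.
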
@@ -90,7 +103,11 @@ def search_train(
         batch_time.update(time.time() - end)
         end = time.time()
         if step % print_freq == 0 or (step + 1) == len(search_loader):
-            Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
+            Sstr = (
+                "**TRAIN** "
+                + time_string()
+                + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
+            )
             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                 batch_time=batch_time, data_time=data_time
             )
@@ -153,7 +170,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
             end = time.time()
             if i % print_freq == 0 or (i + 1) == len(xloader):
-                Sstr = "**VALID** " + time_string() + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
+                Sstr = (
+                    "**VALID** "
+                    + time_string()
+                    + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
+                )
                 Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                     batch_time=batch_time, data_time=data_time
                 )
@@ -165,7 +186,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
     logger.log(
         " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
-            top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
+            top1=top1,
+            top5=top5,
+            error1=100 - top1.avg,
+            error5=100 - top5.avg,
+            loss=losses.avg,
         )
     )