Update SuperMLP
This commit is contained in:
@@ -36,7 +36,12 @@ def search_train_v2(
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
|
||||
base_losses, arch_losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
)
|
||||
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
|
||||
epoch_str, flop_need, flop_weight, flop_tolerant = (
|
||||
extra_info["epoch-str"],
|
||||
@@ -46,10 +51,16 @@ def search_train_v2(
|
||||
)
|
||||
|
||||
network.train()
|
||||
logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
|
||||
logger.log(
|
||||
"[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
|
||||
epoch_str, flop_need, flop_weight
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
network.apply(change_key("search_mode", "search"))
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
|
||||
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
|
||||
search_loader
|
||||
):
|
||||
scheduler.update(None, 1.0 * step / len(search_loader))
|
||||
# calculate prediction and loss
|
||||
base_targets = base_targets.cuda(non_blocking=True)
|
||||
@@ -73,7 +84,9 @@ def search_train_v2(
|
||||
arch_optimizer.zero_grad()
|
||||
logits, expected_flop = network(arch_inputs)
|
||||
flop_cur = network.module.get_flop("genotype", None, None)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
|
||||
flop_loss, flop_loss_scale = get_flop_loss(
|
||||
expected_flop, flop_cur, flop_need, flop_tolerant
|
||||
)
|
||||
acls_loss = criterion(logits, arch_targets)
|
||||
arch_loss = acls_loss + flop_loss * flop_weight
|
||||
arch_loss.backward()
|
||||
@@ -88,7 +101,11 @@ def search_train_v2(
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
if step % print_freq == 0 or (step + 1) == len(search_loader):
|
||||
Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
|
||||
Sstr = (
|
||||
"**TRAIN** "
|
||||
+ time_string()
|
||||
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
|
||||
)
|
||||
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
|
||||
batch_time=batch_time, data_time=data_time
|
||||
)
|
||||
|
Reference in New Issue
Block a user