Update SuperMLP

This commit is contained in:
D-X-Y
2021-03-19 23:57:23 +08:00
parent 31b8122cc1
commit 0c56a729ad
13 changed files with 412 additions and 85 deletions

View File

@@ -7,22 +7,63 @@ from log_utils import time_string
from utils import obtain_accuracy
def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
def basic_train(
xloader,
network,
criterion,
scheduler,
optimizer,
optim_config,
extra_info,
print_freq,
logger,
):
loss, acc1, acc5 = procedure(
xloader, network, criterion, scheduler, optimizer, "train", optim_config, extra_info, print_freq, logger
xloader,
network,
criterion,
scheduler,
optimizer,
"train",
optim_config,
extra_info,
print_freq,
logger,
)
return loss, acc1, acc5
def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
def basic_valid(
xloader, network, criterion, optim_config, extra_info, print_freq, logger
):
with torch.no_grad():
loss, acc1, acc5 = procedure(
xloader, network, criterion, None, None, "valid", None, extra_info, print_freq, logger
xloader,
network,
criterion,
None,
None,
"valid",
None,
extra_info,
print_freq,
logger,
)
return loss, acc1, acc5
def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
def procedure(
xloader,
network,
criterion,
scheduler,
optimizer,
mode,
config,
extra_info,
print_freq,
logger,
):
data_time, batch_time, losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
@@ -39,7 +80,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e
# logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
logger.log(
"[{:5s}] config :: auxiliary={:}".format(mode, config.auxiliary if hasattr(config, "auxiliary") else -1)
"[{:5s}] config :: auxiliary={:}".format(
mode, config.auxiliary if hasattr(config, "auxiliary") else -1
)
)
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
@@ -55,7 +98,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e
features, logits = network(inputs)
if isinstance(logits, list):
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(
2, len(logits)
)
logits, logits_aux = logits
else:
logits, logits_aux = logits, None
@@ -97,7 +142,12 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e
logger.log(
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
mode=mode.upper(),
top1=top1,
top5=top5,
error1=100 - top1.avg,
error5=100 - top5.avg,
loss=losses.avg,
)
)
return losses.avg, top1.avg, top5.avg