Update SuperMLP
This commit is contained in:
@@ -7,22 +7,63 @@ from log_utils import time_string
|
||||
from utils import obtain_accuracy
|
||||
|
||||
|
||||
def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
|
||||
def basic_train(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader, network, criterion, scheduler, optimizer, "train", optim_config, extra_info, print_freq, logger
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
"train",
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
|
||||
def basic_valid(
|
||||
xloader, network, criterion, optim_config, extra_info, print_freq, logger
|
||||
):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader, network, criterion, None, None, "valid", None, extra_info, print_freq, logger
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
None,
|
||||
None,
|
||||
"valid",
|
||||
None,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
|
||||
def procedure(
|
||||
xloader,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
mode,
|
||||
config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time, losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
@@ -39,7 +80,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e
|
||||
|
||||
# logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
|
||||
logger.log(
|
||||
"[{:5s}] config :: auxiliary={:}".format(mode, config.auxiliary if hasattr(config, "auxiliary") else -1)
|
||||
"[{:5s}] config :: auxiliary={:}".format(
|
||||
mode, config.auxiliary if hasattr(config, "auxiliary") else -1
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
for i, (inputs, targets) in enumerate(xloader):
|
||||
@@ -55,7 +98,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e
|
||||
|
||||
features, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
|
||||
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(
|
||||
2, len(logits)
|
||||
)
|
||||
logits, logits_aux = logits
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
@@ -97,7 +142,12 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e
|
||||
|
||||
logger.log(
|
||||
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
|
||||
mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
|
||||
mode=mode.upper(),
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
loss=losses.avg,
|
||||
)
|
||||
)
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
|
Reference in New Issue
Block a user