Update SuperMLP

D-X-Y
2021-03-19 23:57:23 +08:00
parent 31b8122cc1
commit 0c56a729ad
13 changed files with 412 additions and 85 deletions


@@ -10,7 +10,16 @@ from utils import obtain_accuracy
 def simple_KD_train(
-    xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger
+    xloader,
+    teacher,
+    network,
+    criterion,
+    scheduler,
+    optimizer,
+    optim_config,
+    extra_info,
+    print_freq,
+    logger,
 ):
     loss, acc1, acc5 = procedure(
         xloader,
@@ -28,25 +37,58 @@ def simple_KD_train(
     return loss, acc1, acc5


-def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
+def simple_KD_valid(
+    xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger
+):
     with torch.no_grad():
         loss, acc1, acc5 = procedure(
-            xloader, teacher, network, criterion, None, None, "valid", optim_config, extra_info, print_freq, logger
+            xloader,
+            teacher,
+            network,
+            criterion,
+            None,
+            None,
+            "valid",
+            optim_config,
+            extra_info,
+            print_freq,
+            logger,
         )
     return loss, acc1, acc5


 def loss_KD_fn(
-    criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature
+    criterion,
+    student_logits,
+    teacher_logits,
+    studentFeatures,
+    teacherFeatures,
+    targets,
+    alpha,
+    temperature,
 ):
     basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
     log_student = F.log_softmax(student_logits / temperature, dim=1)
     sof_teacher = F.softmax(teacher_logits / temperature, dim=1)
-    KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (alpha * temperature * temperature)
+    KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (
+        alpha * temperature * temperature
+    )
     return basic_loss + KD_loss


-def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
+def procedure(
+    xloader,
+    teacher,
+    network,
+    criterion,
+    scheduler,
+    optimizer,
+    mode,
+    config,
+    extra_info,
+    print_freq,
+    logger,
+):
     data_time, batch_time, losses, top1, top5 = (
         AverageMeter(),
         AverageMeter(),
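
For reference, loss_KD_fn above is the standard Hinton-style knowledge-distillation objective: hard-label cross-entropy weighted by (1 - alpha), plus the KL divergence between temperature-softened student and teacher distributions, rescaled by alpha * T^2 so the soft-target gradients keep a comparable magnitude as the temperature grows. The studentFeatures/teacherFeatures arguments are accepted but unused by this loss. A minimal runnable sketch (PyTorch; the shapes and hyper-parameters are illustrative, not taken from this commit):

import torch
import torch.nn as nn
import torch.nn.functional as F


def kd_loss_sketch(criterion, student_logits, teacher_logits, targets, alpha, temperature):
    # Hard-label term, down-weighted by (1 - alpha).
    basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
    # Soft-target term: KL(teacher || student) over temperature-softened logits.
    log_student = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # The T^2 factor compensates for the 1/T^2 scaling that softened
    # logits introduce into the gradients.
    kd_loss = F.kl_div(log_student, soft_teacher, reduction="batchmean") * (
        alpha * temperature * temperature
    )
    return basic_loss + kd_loss


criterion = nn.CrossEntropyLoss()
student_logits = torch.randn(4, 10)  # batch of 4, 10 classes
teacher_logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))
loss = kd_loss_sketch(criterion, student_logits, teacher_logits, targets, alpha=0.9, temperature=4.0)
print(loss.item())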
@@ -65,7 +107,10 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
     logger.log(
         "[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format(
-            mode, config.auxiliary if hasattr(config, "auxiliary") else -1, config.KD_alpha, config.KD_temperature
+            mode,
+            config.auxiliary if hasattr(config, "auxiliary") else -1,
+            config.KD_alpha,
+            config.KD_temperature,
         )
     )
     end = time.time()
@@ -82,7 +127,9 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
         student_f, logits = network(inputs)
         if isinstance(logits, list):
-            assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
+            assert len(logits) == 2, "logits must have {:} items instead of {:}".format(
+                2, len(logits)
+            )
             logits, logits_aux = logits
         else:
             logits, logits_aux = logits, None
@@ -90,7 +137,14 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
         teacher_f, teacher_logits = teacher(inputs)
         loss = loss_KD_fn(
-            criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature
+            criterion,
+            logits,
+            teacher_logits,
+            student_f,
+            teacher_f,
+            targets,
+            config.KD_alpha,
+            config.KD_temperature,
         )
         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
             loss_aux = criterion(logits_aux, targets)
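
When the student network returns two logits, the second comes from an auxiliary classifier (Inception/NASNet-style deep supervision), and the hunk above computes its cross-entropy as loss_aux. The continuation of the hunk is not shown, so how loss_aux is folded in is inferred; a hedged sketch of the usual pattern, where the auxiliary term is added with weight config.auxiliary and the head is discarded at inference:

def combine_with_auxiliary(loss, logits_aux, targets, criterion, aux_weight):
    # Deep-supervision term: weight the auxiliary cross-entropy by a small
    # factor (commonly around 0.4) and add it to the main KD loss.
    if logits_aux is not None and aux_weight > 0:
        loss = loss + criterion(logits_aux, targets) * aux_weight
    return loss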
@@ -139,7 +193,12 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
     )
     logger.log(
         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
-            mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
+            mode=mode.upper(),
+            top1=top1,
+            top5=top5,
+            error1=100 - top1.avg,
+            error5=100 - top5.avg,
+            loss=losses.avg,
         )
     )
     return losses.avg, top1.avg, top5.avg
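
Putting the entry points together, a hypothetical driver loop (train_loader, valid_loader, student, and total_epochs are illustrative names; only simple_KD_train and simple_KD_valid come from this file):

# Hypothetical epoch loop; assumes the functions above are importable and
# that the teacher, student, optimizer, scheduler, and logger already exist.
for epoch in range(total_epochs):
    train_loss, train_acc1, train_acc5 = simple_KD_train(
        train_loader, teacher, student, criterion, scheduler, optimizer,
        optim_config, extra_info, print_freq, logger,
    )
    # simple_KD_valid already wraps its forward passes in torch.no_grad().
    valid_loss, valid_acc1, valid_acc5 = simple_KD_valid(
        valid_loader, teacher, student, criterion,
        optim_config, extra_info, print_freq, logger,
    )
    logger.log("epoch {:}: valid acc@1 = {:.2f}".format(epoch, valid_acc1))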