Update SuperMLP
This commit is contained in:
@@ -10,7 +10,16 @@ from utils import obtain_accuracy
|
||||
|
||||
|
||||
def simple_KD_train(
|
||||
xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger
|
||||
xloader,
|
||||
teacher,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader,
|
||||
@@ -28,25 +37,58 @@ def simple_KD_train(
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
|
||||
def simple_KD_valid(
|
||||
xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger
|
||||
):
|
||||
with torch.no_grad():
|
||||
loss, acc1, acc5 = procedure(
|
||||
xloader, teacher, network, criterion, None, None, "valid", optim_config, extra_info, print_freq, logger
|
||||
xloader,
|
||||
teacher,
|
||||
network,
|
||||
criterion,
|
||||
None,
|
||||
None,
|
||||
"valid",
|
||||
optim_config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
)
|
||||
return loss, acc1, acc5
|
||||
|
||||
|
||||
def loss_KD_fn(
|
||||
criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature
|
||||
criterion,
|
||||
student_logits,
|
||||
teacher_logits,
|
||||
studentFeatures,
|
||||
teacherFeatures,
|
||||
targets,
|
||||
alpha,
|
||||
temperature,
|
||||
):
|
||||
basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
|
||||
log_student = F.log_softmax(student_logits / temperature, dim=1)
|
||||
sof_teacher = F.softmax(teacher_logits / temperature, dim=1)
|
||||
KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (alpha * temperature * temperature)
|
||||
KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (
|
||||
alpha * temperature * temperature
|
||||
)
|
||||
return basic_loss + KD_loss
|
||||
|
||||
|
||||
def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
|
||||
def procedure(
|
||||
xloader,
|
||||
teacher,
|
||||
network,
|
||||
criterion,
|
||||
scheduler,
|
||||
optimizer,
|
||||
mode,
|
||||
config,
|
||||
extra_info,
|
||||
print_freq,
|
||||
logger,
|
||||
):
|
||||
data_time, batch_time, losses, top1, top5 = (
|
||||
AverageMeter(),
|
||||
AverageMeter(),
|
||||
@@ -65,7 +107,10 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
|
||||
|
||||
logger.log(
|
||||
"[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format(
|
||||
mode, config.auxiliary if hasattr(config, "auxiliary") else -1, config.KD_alpha, config.KD_temperature
|
||||
mode,
|
||||
config.auxiliary if hasattr(config, "auxiliary") else -1,
|
||||
config.KD_alpha,
|
||||
config.KD_temperature,
|
||||
)
|
||||
)
|
||||
end = time.time()
|
||||
@@ -82,7 +127,9 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
|
||||
|
||||
student_f, logits = network(inputs)
|
||||
if isinstance(logits, list):
|
||||
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
|
||||
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(
|
||||
2, len(logits)
|
||||
)
|
||||
logits, logits_aux = logits
|
||||
else:
|
||||
logits, logits_aux = logits, None
|
||||
@@ -90,7 +137,14 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
|
||||
teacher_f, teacher_logits = teacher(inputs)
|
||||
|
||||
loss = loss_KD_fn(
|
||||
criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature
|
||||
criterion,
|
||||
logits,
|
||||
teacher_logits,
|
||||
student_f,
|
||||
teacher_f,
|
||||
targets,
|
||||
config.KD_alpha,
|
||||
config.KD_temperature,
|
||||
)
|
||||
if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
|
||||
loss_aux = criterion(logits_aux, targets)
|
||||
@@ -139,7 +193,12 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
|
||||
)
|
||||
logger.log(
|
||||
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
|
||||
mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
|
||||
mode=mode.upper(),
|
||||
top1=top1,
|
||||
top5=top5,
|
||||
error1=100 - top1.avg,
|
||||
error5=100 - top5.avg,
|
||||
loss=losses.avg,
|
||||
)
|
||||
)
|
||||
return losses.avg, top1.avg, top5.avg
|
||||
|
Reference in New Issue
Block a user