Update LFNA

This commit is contained in:
D-X-Y
2021-05-22 11:02:29 +00:00
parent ec241e4d69
commit 5b09f059fd
4 changed files with 46 additions and 19 deletions

View File

@@ -94,8 +94,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
base_model.train()
meta_model.train()
optimizer = torch.optim.Adam(
meta_model.parameters(),
meta_model.get_parameters(True, True, True),
lr=args.lr,
weight_decay=args.weight_decay,
amsgrad=True,
@@ -103,13 +105,16 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
logger.log("Pre-train the meta-model")
logger.log("Using the optimizer: {:}".format(optimizer))
meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain")
meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
last_success_epoch = 0
per_epoch_time, start_time = AverageMeter(), time.time()
for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
)
total_meta_losses, total_match_losses = [], []
optimizer.zero_grad()
for ibatch in range(args.meta_batch):
rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
timestamps = meta_model.meta_timestamps[
@@ -118,7 +123,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
[seq_containers], time_embeds = meta_model(
torch.unsqueeze(timestamps, dim=0)
torch.unsqueeze(timestamps, dim=0), None
)
# performance loss
losses = []
@@ -136,10 +141,10 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
torch.squeeze(time_embeds, dim=0),
meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
)
# batch_loss = meta_loss + match_loss * 0.1
# total_losses.append(batch_loss)
total_meta_losses.append(meta_loss)
total_match_losses.append(match_loss)
with torch.no_grad():
meta_std = torch.stack(total_meta_losses).std().item()
final_meta_loss = torch.stack(total_meta_losses).mean()
final_match_loss = torch.stack(total_match_losses).mean()
total_loss = final_meta_loss + final_match_loss
@@ -148,11 +153,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
# success
success, best_score = meta_model.save_best(-total_loss.item())
logger.log(
"{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format(
"{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
time_string(),
iepoch,
args.epochs,
total_loss.item(),
meta_std,
final_meta_loss.item(),
final_match_loss.item(),
)
@@ -160,11 +166,15 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
+ ", success={:}, best_score={:.4f}".format(success, -best_score)
+ " {:}".format(left_time)
)
if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
logger.log("Early stop the pre-training at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
meta_model.load_best()
def pretrain(base_model, meta_model, criterion, xenv, args, logger):
def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
base_model.train()
meta_model.train()
optimizer = torch.optim.Adam(
@@ -173,12 +183,13 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
weight_decay=args.weight_decay,
amsgrad=True,
)
logger.log("Pre-train the meta-model")
logger.log("Pre-train the meta-model's embeddings")
logger.log("Using the optimizer: {:}".format(optimizer))
meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain")
meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1")
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
per_epoch_time, start_time = AverageMeter(), time.time()
last_success_epoch = 0
for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
@@ -213,7 +224,7 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
# success
success, best_score = meta_model.save_best(-final_loss.item())
logger.log(
"{:} [{:04d}/{:}] loss : {:.5f}".format(
"{:} [Pre-V1 {:04d}/{:}] loss : {:.5f}".format(
time_string(),
iepoch,
args.epochs,
@@ -223,8 +234,12 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
+ ", success={:}, best_score={:.4f}".format(success, -best_score)
+ " {:}".format(left_time)
)
if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
logger.log("Early stop the pre-training at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
meta_model.load_best()
def main(args):
@@ -282,7 +297,7 @@ def main(args):
logger.log("The scheduler is\n{:}".format(lr_scheduler))
logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))
pretrain(base_model, meta_model, criterion, train_env, args, logger)
pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)
if logger.path("model").exists():
ckp_data = torch.load(logger.path("model"))

View File

@@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
layer_embedding,
time_embedding,
meta_timestamps,
mha_depth: int = 1,
mha_depth: int = 2,
dropout: float = 0.1,
):
super(LFNA_Meta, self).__init__()
@@ -73,7 +73,7 @@ class LFNA_Meta(super_core.SuperModule):
)
)
layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
self.meta_corrector = super_core.SuperSequential(*layers)
self._meta_corrector = super_core.SuperSequential(*layers)
model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"),
@@ -92,6 +92,18 @@ class LFNA_Meta(super_core.SuperModule):
std=0.02,
)
def get_parameters(self, time_embed, meta_corrector, generator):
parameters = []
if time_embed:
parameters.append(self._super_meta_embed)
if meta_corrector:
parameters.extend(list(self._trans_att.parameters()))
parameters.extend(list(self._meta_corrector.parameters()))
if generator:
parameters.append(self._super_layer_embed)
parameters.extend(list(self._generator.parameters()))
return parameters
@property
def meta_timestamps(self):
with torch.no_grad():
@@ -159,7 +171,7 @@ class LFNA_Meta(super_core.SuperModule):
# relative_timestamps = timestamps - timestamps[:, :1]
# relative_pos_embeds = self._tscalar_embed(relative_timestamps)
init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
corrected_embeds = self.meta_corrector(init_timestamp_embeds)
corrected_embeds = self._meta_corrector(init_timestamp_embeds)
return corrected_embeds
def forward_raw(self, timestamps, time_embed):