Update LFNA
@@ -94,8 +94,10 @@ def epoch_evaluate(loader, meta_model, base_model, criterion, device, logger):
 def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     base_model.train()
     meta_model.train()
     optimizer = torch.optim.Adam(
-        meta_model.parameters(),
+        meta_model.get_parameters(True, True, True),
         lr=args.lr,
         weight_decay=args.weight_decay,
         amsgrad=True,
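Note on the optimizer change above: torch.optim.Adam accepts any iterable of parameters, so handing it an explicit selection (here via the new get_parameters, added further down in this commit) restricts the update to those tensors while everything else stays untouched. A minimal self-contained sketch of that behaviour with plain torch modules, not the LFNA code:

import torch

# Only `trained` is registered with the optimizer, so optimizer.step()
# leaves `frozen` exactly as it was, even though it received gradients.
trained = torch.nn.Linear(4, 4)
frozen = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(list(trained.parameters()), lr=0.01, amsgrad=True)

loss = (trained(torch.randn(2, 4)) + frozen(torch.randn(2, 4))).pow(2).mean()
loss.backward()
optimizer.step()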
@@ -103,13 +105,16 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     logger.log("Pre-train the meta-model")
     logger.log("Using the optimizer: {:}".format(optimizer))

-    meta_model.set_best_dir(logger.path(None) / "checkpoint-pretrain")
+    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
+    last_success_epoch = 0
     per_epoch_time, start_time = AverageMeter(), time.time()
     for iepoch in range(args.epochs):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
         )
         total_meta_losses, total_match_losses = [], []
         optimizer.zero_grad()
         for ibatch in range(args.meta_batch):
             rand_index = random.randint(0, meta_model.meta_length - xenv.seq_length - 1)
             timestamps = meta_model.meta_timestamps[
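Each meta-batch element trains on a random contiguous window of xenv.seq_length known timestamps, which is what the rand_index arithmetic above (and the matching super_meta_embed slice in a later hunk) implements. A small sketch of that indexing, with made-up sizes standing in for meta_model.meta_length and xenv.seq_length:

import random
import torch

meta_length, seq_length = 100, 10                 # illustrative sizes only
meta_timestamps = torch.linspace(0.0, 1.0, meta_length)

# Pick a start index so the window [rand_index, rand_index + seq_length)
# stays inside the timestamp table, then slice the contiguous window.
rand_index = random.randint(0, meta_length - seq_length - 1)
timestamps = meta_timestamps[rand_index : rand_index + seq_length]
assert timestamps.shape == (seq_length,)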
@@ -118,7 +123,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
             seq_timestamps, (seq_inputs, seq_targets) = xenv.seq_call(timestamps)
             [seq_containers], time_embeds = meta_model(
-                torch.unsqueeze(timestamps, dim=0)
+                torch.unsqueeze(timestamps, dim=0), None
             )
             # performance loss
             losses = []
@@ -136,10 +141,10 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
                 torch.squeeze(time_embeds, dim=0),
                 meta_model.super_meta_embed[rand_index : rand_index + xenv.seq_length],
             )
             # batch_loss = meta_loss + match_loss * 0.1
             # total_losses.append(batch_loss)
             total_meta_losses.append(meta_loss)
             total_match_losses.append(match_loss)
         with torch.no_grad():
             meta_std = torch.stack(total_meta_losses).std().item()
         final_meta_loss = torch.stack(total_meta_losses).mean()
         final_match_loss = torch.stack(total_match_losses).mean()
         total_loss = final_meta_loss + final_match_loss
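For readers skimming the hunk above: each meta-batch element contributes a performance ("meta") loss and an embedding-matching loss; the lists are reduced by averaging, their sum is optimized with equal weights, and the std of the meta losses is computed under no_grad purely for logging. A runnable sketch with dummy per-element losses:

import torch

# Dummy per-element losses standing in for the meta_batch loop results.
total_meta_losses = [torch.tensor(0.9), torch.tensor(1.1), torch.tensor(1.0)]
total_match_losses = [torch.tensor(0.2), torch.tensor(0.3), torch.tensor(0.1)]

with torch.no_grad():
    meta_std = torch.stack(total_meta_losses).std().item()   # logged only
final_meta_loss = torch.stack(total_meta_losses).mean()
final_match_loss = torch.stack(total_match_losses).mean()
total_loss = final_meta_loss + final_match_loss               # equally weighted sum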
@@ -148,11 +153,12 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
         # success
         success, best_score = meta_model.save_best(-total_loss.item())
         logger.log(
-            "{:} [{:04d}/{:}] loss : {:.5f} = {:.5f} + {:.5f} (match)".format(
+            "{:} [Pre-V2 {:04d}/{:}] loss : {:.5f} +- {:.5f} = {:.5f} + {:.5f} (match)".format(
                 time_string(),
                 iepoch,
                 args.epochs,
                 total_loss.item(),
+                meta_std,
                 final_meta_loss.item(),
                 final_match_loss.item(),
             )
@@ -160,11 +166,15 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
             + ", success={:}, best_score={:.4f}".format(success, -best_score)
             + " {:}".format(left_time)
         )
+        if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
+            logger.log("Early stop the pre-training at {:}".format(iepoch))
+            break
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
     meta_model.load_best()
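save_best appears to keep the highest score seen so far, which is why the loss is negated before the call and negated back (-best_score) for logging; epochs without a new best are measured against last_success_epoch to trigger the early stop above. A hedged, self-contained sketch of that bookkeeping; the real save_best lives on the meta-model and also writes a checkpoint, and the line that refreshes last_success_epoch on success is not visible in this hunk:

best_score, last_success_epoch, early_stop_thresh = None, 0, 20  # 20 stands in for args.early_stop_thresh

def save_best(score):
    """Toy stand-in: return (success, best_score) where higher scores win."""
    global best_score
    if best_score is None or score > best_score:
        best_score = score
        return True, best_score
    return False, best_score

for iepoch, loss in enumerate([1.00, 0.80, 0.90, 0.85, 0.84]):
    success, best = save_best(-loss)           # negate so a smaller loss is a better score
    if success:
        last_success_epoch = iepoch             # assumed update, not shown in this diff
    if iepoch - last_success_epoch >= early_stop_thresh * 5:
        break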


-def pretrain(base_model, meta_model, criterion, xenv, args, logger):
+def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
     base_model.train()
     meta_model.train()
     optimizer = torch.optim.Adam(
@@ -173,12 +183,13 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
         weight_decay=args.weight_decay,
         amsgrad=True,
     )
-    logger.log("Pre-train the meta-model")
+    logger.log("Pre-train the meta-model's embeddings")
     logger.log("Using the optimizer: {:}".format(optimizer))

-    meta_model.set_best_dir(logger.path(None) / "ckps-basic-pretrain")
+    meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1")
     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
     per_epoch_time, start_time = AverageMeter(), time.time()
+    last_success_epoch = 0
     for iepoch in range(args.epochs):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
@@ -213,7 +224,7 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
         # success
         success, best_score = meta_model.save_best(-final_loss.item())
         logger.log(
-            "{:} [{:04d}/{:}] loss : {:.5f}".format(
+            "{:} [Pre-V1 {:04d}/{:}] loss : {:.5f}".format(
                 time_string(),
                 iepoch,
                 args.epochs,
@@ -223,8 +234,12 @@ def pretrain(base_model, meta_model, criterion, xenv, args, logger):
             + ", success={:}, best_score={:.4f}".format(success, -best_score)
             + " {:}".format(left_time)
         )
+        if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
+            logger.log("Early stop the pre-training at {:}".format(iepoch))
+            break
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
     meta_model.load_best()


 def main(args):
@@ -282,7 +297,7 @@ def main(args):
     logger.log("The scheduler is\n{:}".format(lr_scheduler))
     logger.log("Per epoch iterations = {:}".format(len(train_env_loader)))

-    pretrain(base_model, meta_model, criterion, train_env, args, logger)
+    pretrain_v2(base_model, meta_model, criterion, train_env, args, logger)

     if logger.path("model").exists():
         ckp_data = torch.load(logger.path("model"))
@@ -20,7 +20,7 @@ class LFNA_Meta(super_core.SuperModule):
         layer_embedding,
         time_embedding,
         meta_timestamps,
-        mha_depth: int = 1,
+        mha_depth: int = 2,
         dropout: float = 0.1,
     ):
         super(LFNA_Meta, self).__init__()
@@ -73,7 +73,7 @@ class LFNA_Meta(super_core.SuperModule):
             )
         )
         layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
-        self.meta_corrector = super_core.SuperSequential(*layers)
+        self._meta_corrector = super_core.SuperSequential(*layers)

         model_kwargs = dict(
             config=dict(model_type="dual_norm_mlp"),
@@ -92,6 +92,18 @@ class LFNA_Meta(super_core.SuperModule):
             std=0.02,
         )

+    def get_parameters(self, time_embed, meta_corrector, generator):
+        parameters = []
+        if time_embed:
+            parameters.append(self._super_meta_embed)
+        if meta_corrector:
+            parameters.extend(list(self._trans_att.parameters()))
+            parameters.extend(list(self._meta_corrector.parameters()))
+        if generator:
+            parameters.append(self._super_layer_embed)
+            parameters.extend(list(self._generator.parameters()))
+        return parameters
+
     @property
     def meta_timestamps(self):
         with torch.no_grad():
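Usage note for the new selector: reading the body above, the three flags pick the time-embedding table, the attention plus corrector stack, and the layer-embedding plus generator, respectively; pretrain_v2 passes (True, True, True). The toy class below only mirrors that structure so the call is runnable; it is not the real LFNA_Meta, and the partial combination at the end is an illustration rather than something this commit does:

import torch

class ToyMeta(torch.nn.Module):
    """Toy mirror of the parameter groups touched by get_parameters."""

    def __init__(self):
        super().__init__()
        self._super_meta_embed = torch.nn.Parameter(torch.randn(10, 32))
        self._trans_att = torch.nn.MultiheadAttention(32, 4)
        self._meta_corrector = torch.nn.Linear(64, 32)
        self._super_layer_embed = torch.nn.Parameter(torch.randn(6, 32))
        self._generator = torch.nn.Linear(32, 128)

    def get_parameters(self, time_embed, meta_corrector, generator):
        parameters = []
        if time_embed:
            parameters.append(self._super_meta_embed)
        if meta_corrector:
            parameters.extend(list(self._trans_att.parameters()))
            parameters.extend(list(self._meta_corrector.parameters()))
        if generator:
            parameters.append(self._super_layer_embed)
            parameters.extend(list(self._generator.parameters()))
        return parameters

toy = ToyMeta()
opt_all = torch.optim.Adam(toy.get_parameters(True, True, True), lr=0.01, amsgrad=True)
opt_embed_only = torch.optim.Adam(toy.get_parameters(True, False, False), lr=0.01)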
@@ -159,7 +171,7 @@ class LFNA_Meta(super_core.SuperModule):
         # relative_timestamps = timestamps - timestamps[:, :1]
         # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
         init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
-        corrected_embeds = self.meta_corrector(init_timestamp_embeds)
+        corrected_embeds = self._meta_corrector(init_timestamp_embeds)
         return corrected_embeds

     def forward_raw(self, timestamps, time_embed):
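Shape note tying the two LFNA_Meta hunks together: the query's own time embedding is concatenated with the attended timestamp embeddings along the feature axis, so the corrector's final layer (SuperLinear(time_embedding * 2, time_embedding) above) maps 2*d back to d. A shape-only sketch with plain torch modules standing in for the super_core layers and an illustrative embedding size:

import torch

time_embedding, batch, seq = 32, 1, 5                        # illustrative sizes

timestamp_q_embed = torch.randn(batch, seq, time_embedding)  # query time embedding
timestamp_embeds = torch.randn(batch, seq, time_embedding)   # attended embeddings

# Stand-in for the SuperSequential corrector; its last layer maps 2*d -> d.
meta_corrector = torch.nn.Linear(time_embedding * 2, time_embedding)

init_timestamp_embeds = torch.cat((timestamp_q_embed, timestamp_embeds), dim=-1)
corrected_embeds = meta_corrector(init_timestamp_embeds)
print(corrected_embeds.shape)                                # torch.Size([1, 5, 32])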