Update xlayers
@@ -106,8 +106,13 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
     logger.log("Using the optimizer: {:}".format(optimizer))

     meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
+    final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
+    if meta_model.has_best(final_best_name):
+        meta_model.load_best(final_best_name)
+        return
+
     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
-    last_success_epoch = 0
+    last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
     per_epoch_time, start_time = AverageMeter(), time.time()
     for iepoch in range(args.epochs):
         left_time = "Time Left: {:}".format(
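Note: the has_best/load_best guard added above lets a rerun skip pre-training entirely once a final checkpoint exists. A minimal sketch of the same skip-if-cached pattern, using a plain state_dict file and a hypothetical train_fn in place of the repo's meta_model API:

    from pathlib import Path

    import torch


    def pretrain_or_load(model, ckpt_dir, seed, train_fn):
        # Skip-if-cached: reuse the final checkpoint when a previous run finished.
        final_path = Path(ckpt_dir) / "final-pretrain-{:}.pth".format(seed)
        if final_path.exists():
            model.load_state_dict(torch.load(final_path))
            return model
        train_fn(model)  # the (expensive) pre-training loop
        torch.save(model.state_dict(), final_path)
        return model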
@@ -164,14 +169,21 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
             )
             + ", batch={:}".format(len(total_meta_losses))
             + ", success={:}, best_score={:.4f}".format(success, -best_score)
-            + " {:}".format(left_time)
+            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+            + ", {:}".format(left_time)
         )
-        if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
+        if success:
+            last_success_epoch = iepoch
+        if iepoch - last_success_epoch >= early_stop_thresh:
             logger.log("Early stop the pre-training at {:}".format(iepoch))
             break
         per_epoch_time.update(time.time() - start_time)
         start_time = time.time()
     meta_model.load_best()
+    # save to the final model
+    meta_model.set_best_name(final_best_name)
+    success, _ = meta_model.save_best(best_score + 1e-6)
+    assert success


 def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
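Note: both pretrain loops now reset last_success_epoch whenever save_best reports an improvement, and stop once early_stop_thresh epochs pass without one (previously the gap was compared against args.early_stop_thresh * 5). A minimal sketch of this patience-style early stopping, with evaluate and save_best as hypothetical stand-ins for the repo's API:

    def train_with_patience(epochs, early_stop_thresh, evaluate, save_best):
        # Patience counter keyed on checkpoint improvements, not raw loss values.
        last_success_epoch = 0
        for iepoch in range(epochs):
            score = evaluate(iepoch)
            success = save_best(score)  # True when `score` beats the stored best
            if success:
                last_success_epoch = iepoch
            if iepoch - last_success_epoch >= early_stop_thresh:
                print("Early stop the pre-training at {:}".format(iepoch))
                break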
@@ -189,7 +201,7 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
     meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1")
     meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
     per_epoch_time, start_time = AverageMeter(), time.time()
-    last_success_epoch = 0
+    last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
     for iepoch in range(args.epochs):
         left_time = "Time Left: {:}".format(
             convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
@@ -232,9 +244,12 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
             )
             + ", batch={:}".format(len(losses))
             + ", success={:}, best_score={:.4f}".format(success, -best_score)
+            + ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
             + " {:}".format(left_time)
         )
-        if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
+        if success:
+            last_success_epoch = iepoch
+        if iepoch - last_success_epoch >= early_stop_thresh:
             logger.log("Early stop the pre-training at {:}".format(iepoch))
             break
         per_epoch_time.update(time.time() - start_time)
@@ -521,7 +536,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--refine_lr",
         type=float,
-        default=0.005,
+        default=0.001,
         help="The learning rate for the optimizer, during refine",
     )
     parser.add_argument(
@@ -533,6 +548,12 @@ if __name__ == "__main__":
         default=20,
         help="The #epochs for early stop.",
     )
+    parser.add_argument(
+        "--pretrain_early_stop_thresh",
+        type=int,
+        default=200,
+        help="The #epochs for early stop.",
+    )
     parser.add_argument(
         "--seq_length", type=int, default=10, help="The sequence length."
     )
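Note: the new --pretrain_early_stop_thresh flag gives pre-training its own, much larger patience (200 epochs) instead of reusing the general early-stop threshold. A self-contained argparse sketch of the two flags side by side, assuming the default of 20 shown in the context above belongs to --early_stop_thresh:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--early_stop_thresh", type=int, default=20, help="The #epochs for early stop."
    )
    parser.add_argument(
        "--pretrain_early_stop_thresh",
        type=int,
        default=200,
        help="The #epochs for early stop (pre-training).",
    )
    args = parser.parse_args(["--pretrain_early_stop_thresh", "300"])
    assert args.early_stop_thresh == 20 and args.pretrain_early_stop_thresh == 300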
@@ -70,6 +70,7 @@ class LFNA_Meta(super_core.SuperModule):
                 dropout,
                 norm_affine=False,
                 order=super_core.LayerOrder.PostNorm,
+                use_mask=True,
             )
         )
         layers.append(super_core.SuperLinear(time_embedding * 2, time_embedding))
@@ -162,11 +163,14 @@ class LFNA_Meta(super_core.SuperModule):
     def _obtain_time_embed(self, timestamps):
         # timestamps is a batch of sequence of timestamps
         batch, seq = timestamps.shape
+        meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
         timestamp_q_embed = self._tscalar_embed(timestamps)
-        timestamp_k_embed = self._tscalar_embed(self.meta_timestamps.view(1, -1))
-        timestamp_v_embed = self.super_meta_embed.unsqueeze(dim=0)
+        timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
+        timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
+        # create the mask
+        mask = torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
         timestamp_embeds = self._trans_att(
-            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed
+            timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
         )
         # relative_timestamps = timestamps - timestamps[:, :1]
         # relative_pos_embeds = self._tscalar_embed(relative_timestamps)
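Note: the new mask is True at (b, i, j) when meta_timestamps[j] >= timestamps[b, i], i.e. it flags meta-embeddings whose timestamp lies at or after the query time. A minimal sketch of masked scaled-dot-product attention under the convention that True entries are dropped (whether _trans_att uses this or the inverse convention is an assumption here):

    import torch


    def masked_attention(q, k, v, mask):
        # q: (batch, seq, dim); k, v: (1, num_meta, dim)
        # mask: (batch, seq, num_meta), True = drop (meta slot not yet available)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        scores = scores.masked_fill(mask, float("-inf"))
        return torch.matmul(torch.softmax(scores, dim=-1), v)


    timestamps = torch.tensor([[0.1, 0.5, 0.9]])          # (batch=1, seq=3)
    meta_timestamps = torch.tensor([0.0, 0.4, 0.8, 1.2])  # num_meta=4
    q = torch.randn(1, 3, 8)
    k = torch.randn(1, 4, 8)
    v = torch.randn(1, 4, 8)
    mask = timestamps.unsqueeze(-1) <= meta_timestamps.view(1, 1, -1)
    out = masked_attention(q, k, v, mask)  # (1, 3, 8); each step sees only the past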
@@ -186,8 +190,12 @@ class LFNA_Meta(super_core.SuperModule):
         layer_embed = self._super_layer_embed.view(1, 1, num_layer, -1).expand(
             batch, seq, -1, -1
         )
-        joint_embed = torch.cat((meta_embed, layer_embed), dim=-1)
-        batch_weights = self._generator(joint_embed)
+        joint_embed = torch.cat(
+            (meta_embed, layer_embed), dim=-1
+        )  # batch, seq, num-layers, input-dim
+        batch_weights = self._generator(
+            joint_embed
+        )  # batch, seq, num-layers, num-weights
         batch_containers = []
         for seq_weights in torch.split(batch_weights, 1):
             seq_containers = []
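Note: the re-wrapped lines above only add shape comments: joint_embed concatenates the per-timestamp meta embedding with a per-layer embedding along the feature axis, and the generator maps every (timestep, layer) slot to a flat vector of synthesized weights. A shape-level sketch with made-up dimensions and a plain nn.Linear standing in for the repo's _generator:

    import torch

    batch, seq, num_layer = 2, 5, 4
    meta_dim, layer_dim, num_weights = 16, 8, 32

    meta_embed = torch.randn(batch, seq, num_layer, meta_dim)
    layer_embed = torch.randn(1, 1, num_layer, layer_dim).expand(batch, seq, -1, -1)

    # (batch, seq, num_layer, meta_dim + layer_dim)
    joint_embed = torch.cat((meta_embed, layer_embed), dim=-1)
    generator = torch.nn.Linear(meta_dim + layer_dim, num_weights)
    batch_weights = generator(joint_embed)  # (batch, seq, num_layer, num_weights)
    assert batch_weights.shape == (batch, seq, num_layer, num_weights)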