Update xlayers

This commit is contained in:
D-X-Y
2021-05-22 23:04:24 +08:00
parent 5b09f059fd
commit 8109ed166a
6 changed files with 104 additions and 33 deletions

View File

@@ -106,8 +106,13 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
logger.log("Using the optimizer: {:}".format(optimizer))
meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
if meta_model.has_best(final_best_name):
meta_model.load_best(final_best_name)
return
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
last_success_epoch = 0
last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
per_epoch_time, start_time = AverageMeter(), time.time()
for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format(
@@ -164,14 +169,21 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):
)
+ ", batch={:}".format(len(total_meta_losses))
+ ", success={:}, best_score={:.4f}".format(success, -best_score)
+ " {:}".format(left_time)
+ ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+ ", {:}".format(left_time)
)
if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
if success:
last_success_epoch = iepoch
if iepoch - last_success_epoch >= early_stop_thresh:
logger.log("Early stop the pre-training at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
meta_model.load_best()
# save to the final model
meta_model.set_best_name(final_best_name)
success, _ = meta_model.save_best(best_score + 1e-6)
assert success
def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
@@ -189,7 +201,7 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v1")
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
per_epoch_time, start_time = AverageMeter(), time.time()
last_success_epoch = 0
last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
@@ -232,9 +244,12 @@ def pretrain_v1(base_model, meta_model, criterion, xenv, args, logger):
)
+ ", batch={:}".format(len(losses))
+ ", success={:}, best_score={:.4f}".format(success, -best_score)
+ ", LS={:}/{:}".format(last_success_epoch, early_stop_thresh)
+ " {:}".format(left_time)
)
if iepoch - last_success_epoch >= args.early_stop_thresh * 5:
if success:
last_success_epoch = iepoch
if iepoch - last_success_epoch >= early_stop_thresh:
logger.log("Early stop the pre-training at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time)
@@ -521,7 +536,7 @@ if __name__ == "__main__":
parser.add_argument(
"--refine_lr",
type=float,
default=0.005,
default=0.001,
help="The learning rate for the optimizer, during refine",
)
parser.add_argument(
@@ -533,6 +548,12 @@ if __name__ == "__main__":
default=20,
help="The #epochs for early stop.",
)
parser.add_argument(
"--pretrain_early_stop_thresh",
type=int,
default=200,
help="The #epochs for early stop.",
)
parser.add_argument(
"--seq_length", type=int, default=10, help="The sequence length."
)