Re-org GeMOSA codes

This commit is contained in:
D-X-Y
2021-05-27 11:17:57 +08:00
parent a507f8dd94
commit 8961215416
8 changed files with 82 additions and 162 deletions

View File

@@ -35,7 +35,7 @@ from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core, trunc_normal_
from lfna_utils import lfna_setup, train_model, TimeData
from lfna_meta_model import MetaModelV1
from meta_model import MetaModelV1
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
@@ -106,7 +106,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
)
optimizer.zero_grad()
generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True)
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
@@ -219,11 +219,11 @@ def main(args):
w_containers, loss_meter = online_evaluate(
all_env, meta_model, base_model, criterion, args, logger, True
)
logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
save_checkpoint(
{"w_containers": w_containers},
logger.path(None) / "final-ckp.pth",
{"all_w_containers": w_containers},
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
logger,
)