Re-org GeMOSA codes
This commit is contained in:
@@ -35,7 +35,7 @@ from xautodl.models.xcore import get_model
|
||||
from xautodl.xlayers import super_core, trunc_normal_
|
||||
|
||||
from lfna_utils import lfna_setup, train_model, TimeData
|
||||
from lfna_meta_model import MetaModelV1
|
||||
from meta_model import MetaModelV1
|
||||
|
||||
|
||||
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
|
||||
@@ -106,7 +106,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
||||
)
|
||||
optimizer.zero_grad()
|
||||
|
||||
generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True)
|
||||
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
|
||||
|
||||
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
|
||||
|
||||
@@ -219,11 +219,11 @@ def main(args):
|
||||
w_containers, loss_meter = online_evaluate(
|
||||
all_env, meta_model, base_model, criterion, args, logger, True
|
||||
)
|
||||
logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))
|
||||
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
||||
|
||||
save_checkpoint(
|
||||
{"w_containers": w_containers},
|
||||
logger.path(None) / "final-ckp.pth",
|
||||
{"all_w_containers": w_containers},
|
||||
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
|
||||
logger,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user