Update GeMOSA v4
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@@ -32,15 +33,24 @@ from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
|
||||
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
|
||||
from xautodl.datasets.synthetic_core import get_synthetic_env
|
||||
from xautodl.models.xcore import get_model
|
||||
from xautodl.xlayers import super_core, trunc_normal_
|
||||
from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric
|
||||
|
||||
from meta_model import MetaModelV1
|
||||
|
||||
|
||||
def online_evaluate(
|
||||
env, meta_model, base_model, criterion, args, logger, save=False, easy_adapt=False
|
||||
env,
|
||||
meta_model,
|
||||
base_model,
|
||||
criterion,
|
||||
metric,
|
||||
args,
|
||||
logger,
|
||||
save=False,
|
||||
easy_adapt=False,
|
||||
):
|
||||
logger.log("Online evaluate: {:}".format(env))
|
||||
metric.reset()
|
||||
loss_meter = AverageMeter()
|
||||
w_containers = dict()
|
||||
for idx, (future_time, (future_x, future_y)) in enumerate(env):
|
||||
@@ -57,6 +67,8 @@ def online_evaluate(
|
||||
future_y_hat = base_model.forward_with_container(future_x, future_container)
|
||||
future_loss = criterion(future_y_hat, future_y)
|
||||
loss_meter.update(future_loss.item())
|
||||
# accumulate the metric scores
|
||||
metric(future_y_hat, future_y)
|
||||
if easy_adapt:
|
||||
meta_model.easy_adapt(future_time.item(), future_time_embed)
|
||||
refine, post_refine_loss = False, -1
|
||||
@@ -79,7 +91,7 @@ def online_evaluate(
|
||||
)
|
||||
meta_model.clear_fixed()
|
||||
meta_model.clear_learnt()
|
||||
return w_containers, loss_meter
|
||||
return w_containers, loss_meter.avg, metric.get_info()["score"]
|
||||
|
||||
|
||||
def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
|
||||
@@ -203,7 +215,16 @@ def main(args):
|
||||
|
||||
base_model = get_model(**model_kwargs)
|
||||
base_model = base_model.to(args.device)
|
||||
criterion = torch.nn.MSELoss()
|
||||
if all_env.meta_info["task"] == "regression":
|
||||
criterion = torch.nn.MSELoss()
|
||||
metric = MSEMetric(True)
|
||||
elif all_env.meta_info["task"] == "classification":
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
metric = Top1AccMetric(True)
|
||||
else:
|
||||
raise ValueError(
|
||||
"This task ({:}) is not supported.".format(all_env.meta_info["task"])
|
||||
)
|
||||
|
||||
shape_container = base_model.get_w_container().to_shape_container()
|
||||
|
||||
@@ -235,27 +256,29 @@ def main(args):
|
||||
)
|
||||
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
|
||||
"""
|
||||
_, test_loss_meter_adapt_v1 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, args, logger, False, False
|
||||
_, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, metric, args, logger, False, False
|
||||
)
|
||||
_, test_loss_meter_adapt_v2 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, args, logger, False, True
|
||||
_, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
|
||||
valid_env, meta_model, base_model, criterion, metric, args, logger, False, True
|
||||
)
|
||||
logger.log(
|
||||
"In the online test enviornment, the total loss for refine-adapt is {:}".format(
|
||||
test_loss_meter_adapt_v1
|
||||
"[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
|
||||
loss_adapt_v1, metric_adapt_v1
|
||||
)
|
||||
)
|
||||
logger.log(
|
||||
"In the online test enviornment, the total loss for easy-adapt is {:}".format(
|
||||
test_loss_meter_adapt_v2
|
||||
"[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format(
|
||||
loss_adapt_v2, metric_adapt_v2
|
||||
)
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
{
|
||||
"test_loss_adapt_v1": test_loss_meter_adapt_v1.avg,
|
||||
"test_loss_adapt_v2": test_loss_meter_adapt_v2.avg,
|
||||
"test_loss_adapt_v1": loss_adapt_v1,
|
||||
"test_loss_adapt_v2": loss_adapt_v2,
|
||||
"test_metric_adapt_v1": metric_adapt_v1,
|
||||
"test_metric_adapt_v2": metric_adapt_v2,
|
||||
},
|
||||
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
|
||||
logger,
|
||||
|
Reference in New Issue
Block a user