Update ablation for GeMOSA

This commit is contained in:
D-X-Y
2021-05-27 19:47:08 +08:00
parent 08337138f1
commit 5dd75696c9
4 changed files with 304 additions and 16 deletions

View File

@@ -4,6 +4,7 @@
# python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/basic-same.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
# python exps/GeMOSA/basic-same.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
@@ -28,7 +29,12 @@ from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.procedures.metric_utils import (
SaveMetric,
MSEMetric,
Top1AccMetric,
ComposeMetric,
)
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
@@ -57,6 +63,17 @@ def main(args):
logger.log("The total enviornment: {:}".format(env))
w_containers = dict()
if env.meta_info["task"] == "regression":
criterion = torch.nn.MSELoss()
metric_cls = MSEMetric
elif env.meta_info["task"] == "classification":
criterion = torch.nn.CrossEntropyLoss()
metric_cls = Top1AccMetric
else:
raise ValueError(
"This task ({:}) is not supported.".format(all_env.meta_info["task"])
)
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx, (future_time, (future_x, future_y)) in enumerate(env):
@@ -79,7 +96,6 @@ def main(args):
print(model)
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss()
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
@@ -89,7 +105,7 @@ def main(args):
],
gamma=0.3,
)
train_metric = MSEMetric()
train_metric = metric_cls(True)
best_loss, best_param = None, None
for _iepoch in range(args.epochs):
preds = model(historical_x)
@@ -108,19 +124,19 @@ def main(args):
train_metric(preds, historical_y)
train_results = train_metric.get_info()
metric = ComposeMetric(MSEMetric(), SaveMetric())
xmetric = ComposeMetric(metric_cls(True), SaveMetric())
eval_dataset = torch.utils.data.TensorDataset(
future_x.to(args.device), future_y.to(args.device)
)
eval_loader = torch.utils.data.DataLoader(
eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
)
results = basic_eval_fn(eval_loader, model, metric, logger)
results = basic_eval_fn(eval_loader, model, xmetric, logger)
log_str = (
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(env))
+ " train-mse: {:.5f}, eval-mse: {:.5f}".format(
train_results["mse"], results["mse"]
+ " train-score: {:.5f}, eval-score: {:.5f}".format(
train_results["score"], results["score"]
)
)
logger.log(log_str)