Update ablation for GeMOSA
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
# python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
|
||||
# python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
|
||||
# python exps/GeMOSA/basic-same.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
|
||||
# python exps/GeMOSA/basic-same.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
|
||||
#####################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
@@ -28,7 +29,12 @@ from xautodl.log_utils import AverageMeter, convert_secs2time
|
||||
from xautodl.utils import split_str2indexes
|
||||
|
||||
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
|
||||
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
|
||||
from xautodl.procedures.metric_utils import (
|
||||
SaveMetric,
|
||||
MSEMetric,
|
||||
Top1AccMetric,
|
||||
ComposeMetric,
|
||||
)
|
||||
from xautodl.datasets.synthetic_core import get_synthetic_env
|
||||
from xautodl.models.xcore import get_model
|
||||
|
||||
@@ -57,6 +63,17 @@ def main(args):
|
||||
logger.log("The total enviornment: {:}".format(env))
|
||||
w_containers = dict()
|
||||
|
||||
if env.meta_info["task"] == "regression":
|
||||
criterion = torch.nn.MSELoss()
|
||||
metric_cls = MSEMetric
|
||||
elif env.meta_info["task"] == "classification":
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
metric_cls = Top1AccMetric
|
||||
else:
|
||||
raise ValueError(
|
||||
"This task ({:}) is not supported.".format(all_env.meta_info["task"])
|
||||
)
|
||||
|
||||
per_timestamp_time, start_time = AverageMeter(), time.time()
|
||||
for idx, (future_time, (future_x, future_y)) in enumerate(env):
|
||||
|
||||
@@ -79,7 +96,6 @@ def main(args):
|
||||
print(model)
|
||||
# build optimizer
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
|
||||
criterion = torch.nn.MSELoss()
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
milestones=[
|
||||
@@ -89,7 +105,7 @@ def main(args):
|
||||
],
|
||||
gamma=0.3,
|
||||
)
|
||||
train_metric = MSEMetric()
|
||||
train_metric = metric_cls(True)
|
||||
best_loss, best_param = None, None
|
||||
for _iepoch in range(args.epochs):
|
||||
preds = model(historical_x)
|
||||
@@ -108,19 +124,19 @@ def main(args):
|
||||
train_metric(preds, historical_y)
|
||||
train_results = train_metric.get_info()
|
||||
|
||||
metric = ComposeMetric(MSEMetric(), SaveMetric())
|
||||
xmetric = ComposeMetric(metric_cls(True), SaveMetric())
|
||||
eval_dataset = torch.utils.data.TensorDataset(
|
||||
future_x.to(args.device), future_y.to(args.device)
|
||||
)
|
||||
eval_loader = torch.utils.data.DataLoader(
|
||||
eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
|
||||
)
|
||||
results = basic_eval_fn(eval_loader, model, metric, logger)
|
||||
results = basic_eval_fn(eval_loader, model, xmetric, logger)
|
||||
log_str = (
|
||||
"[{:}]".format(time_string())
|
||||
+ " [{:04d}/{:04d}]".format(idx, len(env))
|
||||
+ " train-mse: {:.5f}, eval-mse: {:.5f}".format(
|
||||
train_results["mse"], results["mse"]
|
||||
+ " train-score: {:.5f}, eval-score: {:.5f}".format(
|
||||
train_results["score"], results["score"]
|
||||
)
|
||||
)
|
||||
logger.log(log_str)
|
||||
|
Reference in New Issue
Block a user