Update ablation for GeMOSA

This commit is contained in:
D-X-Y
2021-05-27 19:47:08 +08:00
parent 08337138f1
commit 5dd75696c9
4 changed files with 304 additions and 16 deletions

View File

@@ -1,12 +1,16 @@
#####################################################
# Learning to Generate Model One Step Ahead #
#####################################################
##########################################################
# Learning to Efficiently Generate Models One Step Ahead #
##########################################################
# <----> run on CPU
# python exps/GeMOSA/main.py --env_version v1 --workers 0
# <----> run on a GPU
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
#####################################################
# <----> ablation commands
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda
##########################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
@@ -36,6 +40,7 @@ from xautodl.models.xcore import get_model
from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric
from meta_model import MetaModelV1
from meta_model_ablation import MetaModel_TraditionalAtt
def online_evaluate(
@@ -230,7 +235,13 @@ def main(args):
# pre-train the hypernetwork
timestamps = trainval_env.get_timestamp(None)
meta_model = MetaModelV1(
# Choose which meta-model class to instantiate based on --ablation:
#   None  -> MetaModelV1 (the default model)
#   "old" -> MetaModel_TraditionalAtt (imported from meta_model_ablation)
# Any other value is rejected with a ValueError.
if args.ablation is None:
MetaModel_cls = MetaModelV1
elif args.ablation == "old":
MetaModel_cls = MetaModel_TraditionalAtt
else:
raise ValueError("Unknown ablation : {:}".format(args.ablation))
meta_model = MetaModel_cls(
shape_container,
args.layer_dim,
args.time_dim,
@@ -373,6 +384,9 @@ if __name__ == "__main__":
parser.add_argument(
"--workers", type=int, default=4, help="The number of workers in parallel."
)
# Optional ablation switch. main() accepts None (the default) or "old" and
# raises ValueError for anything else; the value is also embedded in the
# generated save_dir name.
parser.add_argument(
"--ablation", type=str, default=None, help="The ablation indicator."
)
parser.add_argument(
"--device",
type=str,
@@ -385,7 +399,7 @@ if __name__ == "__main__":
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format(
args.save_dir,
args.meta_batch,
args.hidden_dim,
@@ -395,6 +409,7 @@ if __name__ == "__main__":
args.lr,
args.weight_decay,
args.epochs,
args.ablation,
args.env_version,
)
main(args)