Update ablation for GeMOSA
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
#####################################################
|
||||
# Learning to Generate Model One Step Ahead #
|
||||
#####################################################
|
||||
##########################################################
|
||||
# Learning to Efficiently Generate Models One Step Ahead #
|
||||
##########################################################
|
||||
# <----> run on CPU
|
||||
# python exps/GeMOSA/main.py --env_version v1 --workers 0
|
||||
# <----> run on a GPU
|
||||
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
|
||||
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
|
||||
#####################################################
|
||||
# <----> ablation commands
|
||||
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda
|
||||
##########################################################
|
||||
import sys, time, copy, torch, random, argparse
|
||||
from tqdm import tqdm
|
||||
from copy import deepcopy
|
||||
@@ -36,6 +40,7 @@ from xautodl.models.xcore import get_model
|
||||
from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric
|
||||
|
||||
from meta_model import MetaModelV1
|
||||
from meta_model_ablation import MetaModel_TraditionalAtt
|
||||
|
||||
|
||||
def online_evaluate(
|
||||
@@ -230,7 +235,13 @@ def main(args):
|
||||
|
||||
# pre-train the hypernetwork
|
||||
timestamps = trainval_env.get_timestamp(None)
|
||||
meta_model = MetaModelV1(
|
||||
if args.ablation is None:
|
||||
MetaModel_cls = MetaModelV1
|
||||
elif args.ablation == "old":
|
||||
MetaModel_cls = MetaModel_TraditionalAtt
|
||||
else:
|
||||
raise ValueError("Unknown ablation : {:}".format(args.ablation))
|
||||
meta_model = MetaModel_cls(
|
||||
shape_container,
|
||||
args.layer_dim,
|
||||
args.time_dim,
|
||||
@@ -373,6 +384,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--workers", type=int, default=4, help="The number of workers in parallel."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ablation", type=str, default=None, help="The ablation indicator."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
@@ -385,7 +399,7 @@ if __name__ == "__main__":
|
||||
if args.rand_seed is None or args.rand_seed < 0:
|
||||
args.rand_seed = random.randint(1, 100000)
|
||||
assert args.save_dir is not None, "The save dir argument can not be None"
|
||||
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
|
||||
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format(
|
||||
args.save_dir,
|
||||
args.meta_batch,
|
||||
args.hidden_dim,
|
||||
@@ -395,6 +409,7 @@ if __name__ == "__main__":
|
||||
args.lr,
|
||||
args.weight_decay,
|
||||
args.epochs,
|
||||
args.ablation,
|
||||
args.env_version,
|
||||
)
|
||||
main(args)
|
||||
|
Reference in New Issue
Block a user