add autodl

This commit is contained in:
mhz
2024-08-25 18:02:31 +02:00
parent 192f286cfb
commit a0a25f291c
431 changed files with 50646 additions and 8 deletions

View File

@@ -0,0 +1,319 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/maml-ft.py --env_version v1 --hidden_dim 16 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-ft.py --env_version v2 --hidden_dim 16 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-ft.py --env_version v3 --hidden_dim 32 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-ft.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
print(lib_dir)
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, Top1AccMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core
class MAML:
"""A LFNA meta-model that uses the MLP as delta-net."""
def __init__(
self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
):
self.criterion = criterion
self.network = network
self.meta_optimizer = torch.optim.Adam(
self.network.parameters(), lr=meta_lr, amsgrad=True
)
self.inner_lr = inner_lr
self.inner_step = inner_step
self._best_info = dict(state_dict=None, iepoch=None, score=None)
print("There are {:} weights.".format(self.network.get_w_container().numel()))
def adapt(self, x, y):
# create a container for the future timestamp
container = self.network.get_w_container()
for k in range(0, self.inner_step):
y_hat = self.network.forward_with_container(x, container)
loss = self.criterion(y_hat, y)
grads = torch.autograd.grad(loss, container.parameters())
container = container.additive([-self.inner_lr * grad for grad in grads])
return container
def predict(self, x, container=None):
if container is not None:
y_hat = self.network.forward_with_container(x, container)
else:
y_hat = self.network(x)
return y_hat
def step(self):
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
self.meta_optimizer.step()
def zero_grad(self):
self.meta_optimizer.zero_grad()
def load_state_dict(self, state_dict):
self.criterion.load_state_dict(state_dict["criterion"])
self.network.load_state_dict(state_dict["network"])
self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])
def state_dict(self):
state_dict = dict()
state_dict["criterion"] = self.criterion.state_dict()
state_dict["network"] = self.network.state_dict()
state_dict["meta_optimizer"] = self.meta_optimizer.state_dict()
return state_dict
def save_best(self, score):
success, best_score = self.network.save_best(score)
return success, best_score
def load_best(self):
self.network.load_best()
def main(args):
prepare_seed(args.rand_seed)
logger = prepare_logger(args)
train_env = get_synthetic_env(mode="train", version=args.env_version)
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
test_env = get_synthetic_env(mode="test", version=args.env_version)
all_env = get_synthetic_env(mode=None, version=args.env_version)
logger.log("The training enviornment: {:}".format(train_env))
logger.log("The validation enviornment: {:}".format(valid_env))
logger.log("The trainval enviornment: {:}".format(trainval_env))
logger.log("The total enviornment: {:}".format(all_env))
logger.log("The test enviornment: {:}".format(test_env))
model_kwargs = dict(
config=dict(model_type="norm_mlp"),
input_dim=all_env.meta_info["input_dim"],
output_dim=all_env.meta_info["output_dim"],
hidden_dims=[args.hidden_dim] * 2,
act_cls="relu",
norm_cls="layer_norm_1d",
)
model = get_model(**model_kwargs)
model = model.to(args.device)
if all_env.meta_info["task"] == "regression":
criterion = torch.nn.MSELoss()
metric_cls = MSEMetric
elif all_env.meta_info["task"] == "classification":
criterion = torch.nn.CrossEntropyLoss()
metric_cls = Top1AccMetric
else:
raise ValueError(
"This task ({:}) is not supported.".format(all_env.meta_info["task"])
)
maml = MAML(
model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
)
# meta-training
last_success_epoch = 0
per_epoch_time, start_time = AverageMeter(), time.time()
for iepoch in range(args.epochs):
need_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
)
head_str = (
"[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
+ need_time
)
maml.zero_grad()
meta_losses = []
for ibatch in range(args.meta_batch):
future_idx = random.randint(0, len(trainval_env) - 1)
future_t, (future_x, future_y) = trainval_env[future_idx]
# -->>
seq_times = trainval_env.get_seq_times(future_idx, args.seq_length)
_, (allxs, allys) = trainval_env.seq_call(seq_times)
allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
if trainval_env.meta_info["task"] == "classification":
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
future_container = maml.adapt(historical_x, historical_y)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = maml.criterion(future_y_hat, future_y)
meta_losses.append(future_loss)
meta_loss = torch.stack(meta_losses).mean()
meta_loss.backward()
maml.step()
logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
success, best_score = maml.save_best(-meta_loss.item())
if success:
logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
save_checkpoint(maml.state_dict(), logger.path("model"), logger)
last_success_epoch = iepoch
if iepoch - last_success_epoch >= args.early_stop_thresh:
logger.log("Early stop at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
# meta-test
maml.load_best()
def finetune(index):
seq_times = test_env.get_seq_times(index, args.seq_length)
_, (allxs, allys) = test_env.seq_call(seq_times)
allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
if test_env.meta_info["task"] == "classification":
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
future_container = maml.adapt(historical_x, historical_y)
historical_y_hat = maml.predict(historical_x, future_container)
train_metric = metric_cls(True)
# model.analyze_weights()
with torch.no_grad():
train_metric(historical_y_hat, historical_y)
train_results = train_metric.get_info()
return train_results, future_container
metric = metric_cls(True)
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
need_time = "Time Left: {:}".format(
convert_secs2time(per_timestamp_time.avg * (len(test_env) - idx), True)
)
logger.log(
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(test_env))
+ " "
+ need_time
)
# build optimizer
train_results, future_container = finetune(idx)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
metric(future_y_hat, future_y)
log_str = (
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(test_env))
+ " train-score: {:.5f}, eval-score: {:.5f}".format(
train_results["score"], metric.get_info()["score"]
)
)
logger.log(log_str)
logger.log("")
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
logger.log("-" * 200 + "\n")
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Use the maml.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/GeMOSA-synthetic/use-maml-ft",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
default=16,
help="The hidden dimension.",
)
parser.add_argument(
"--meta_lr",
type=float,
default=0.02,
help="The learning rate for the MAML optimizer (default is Adam)",
)
parser.add_argument(
"--inner_lr",
type=float,
default=0.005,
help="The learning rate for the inner optimization",
)
parser.add_argument(
"--inner_step", type=int, default=1, help="The inner loop steps for MAML."
)
parser.add_argument(
"--seq_length", type=int, default=20, help="The sequence length."
)
parser.add_argument(
"--meta_batch",
type=int,
default=256,
help="The batch size for the meta-model",
)
parser.add_argument(
"--epochs",
type=int,
default=2000,
help="The total number of epochs.",
)
parser.add_argument(
"--early_stop_thresh",
type=int,
default=50,
help="The maximum epochs for early stop.",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="",
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="The number of data loading workers (default: 4)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-s{:}-mlr{:}-d{:}-e{:}-env{:}".format(
args.save_dir,
args.inner_step,
args.meta_lr,
args.hidden_dim,
args.epochs,
args.env_version,
)
main(args)

View File

@@ -0,0 +1,319 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/maml-nof.py --env_version v1 --hidden_dim 16 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-nof.py --env_version v2 --hidden_dim 16 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-nof.py --env_version v3 --hidden_dim 32 --inner_step 5 --device cuda
# python exps/GeMOSA/baselines/maml-nof.py --env_version v4 --hidden_dim 32 --inner_step 5 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
print(lib_dir)
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, Top1AccMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core
class MAML:
"""A LFNA meta-model that uses the MLP as delta-net."""
def __init__(
self, network, criterion, epochs, meta_lr, inner_lr=0.01, inner_step=1
):
self.criterion = criterion
self.network = network
self.meta_optimizer = torch.optim.Adam(
self.network.parameters(), lr=meta_lr, amsgrad=True
)
self.inner_lr = inner_lr
self.inner_step = inner_step
self._best_info = dict(state_dict=None, iepoch=None, score=None)
print("There are {:} weights.".format(self.network.get_w_container().numel()))
def adapt(self, x, y):
# create a container for the future timestamp
container = self.network.get_w_container()
for k in range(0, self.inner_step):
y_hat = self.network.forward_with_container(x, container)
loss = self.criterion(y_hat, y)
grads = torch.autograd.grad(loss, container.parameters())
container = container.additive([-self.inner_lr * grad for grad in grads])
return container
def predict(self, x, container=None):
if container is not None:
y_hat = self.network.forward_with_container(x, container)
else:
y_hat = self.network(x)
return y_hat
def step(self):
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
self.meta_optimizer.step()
def zero_grad(self):
self.meta_optimizer.zero_grad()
def load_state_dict(self, state_dict):
self.criterion.load_state_dict(state_dict["criterion"])
self.network.load_state_dict(state_dict["network"])
self.meta_optimizer.load_state_dict(state_dict["meta_optimizer"])
def state_dict(self):
state_dict = dict()
state_dict["criterion"] = self.criterion.state_dict()
state_dict["network"] = self.network.state_dict()
state_dict["meta_optimizer"] = self.meta_optimizer.state_dict()
return state_dict
def save_best(self, score):
success, best_score = self.network.save_best(score)
return success, best_score
def load_best(self):
self.network.load_best()
def main(args):
prepare_seed(args.rand_seed)
logger = prepare_logger(args)
train_env = get_synthetic_env(mode="train", version=args.env_version)
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
test_env = get_synthetic_env(mode="test", version=args.env_version)
all_env = get_synthetic_env(mode=None, version=args.env_version)
logger.log("The training enviornment: {:}".format(train_env))
logger.log("The validation enviornment: {:}".format(valid_env))
logger.log("The trainval enviornment: {:}".format(trainval_env))
logger.log("The total enviornment: {:}".format(all_env))
logger.log("The test enviornment: {:}".format(test_env))
model_kwargs = dict(
config=dict(model_type="norm_mlp"),
input_dim=all_env.meta_info["input_dim"],
output_dim=all_env.meta_info["output_dim"],
hidden_dims=[args.hidden_dim] * 2,
act_cls="relu",
norm_cls="layer_norm_1d",
)
model = get_model(**model_kwargs)
model = model.to(args.device)
if all_env.meta_info["task"] == "regression":
criterion = torch.nn.MSELoss()
metric_cls = MSEMetric
elif all_env.meta_info["task"] == "classification":
criterion = torch.nn.CrossEntropyLoss()
metric_cls = Top1AccMetric
else:
raise ValueError(
"This task ({:}) is not supported.".format(all_env.meta_info["task"])
)
maml = MAML(
model, criterion, args.epochs, args.meta_lr, args.inner_lr, args.inner_step
)
# meta-training
last_success_epoch = 0
per_epoch_time, start_time = AverageMeter(), time.time()
for iepoch in range(args.epochs):
need_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
)
head_str = (
"[{:}] [{:04d}/{:04d}] ".format(time_string(), iepoch, args.epochs)
+ need_time
)
maml.zero_grad()
meta_losses = []
for ibatch in range(args.meta_batch):
future_idx = random.randint(0, len(trainval_env) - 1)
future_t, (future_x, future_y) = trainval_env[future_idx]
# -->>
seq_times = trainval_env.get_seq_times(future_idx, args.seq_length)
_, (allxs, allys) = trainval_env.seq_call(seq_times)
allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
if trainval_env.meta_info["task"] == "classification":
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
future_container = maml.adapt(historical_x, historical_y)
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = maml.criterion(future_y_hat, future_y)
meta_losses.append(future_loss)
meta_loss = torch.stack(meta_losses).mean()
meta_loss.backward()
maml.step()
logger.log(head_str + " meta-loss: {:.4f}".format(meta_loss.item()))
success, best_score = maml.save_best(-meta_loss.item())
if success:
logger.log("Achieve the best with best_score = {:.3f}".format(best_score))
save_checkpoint(maml.state_dict(), logger.path("model"), logger)
last_success_epoch = iepoch
if iepoch - last_success_epoch >= args.early_stop_thresh:
logger.log("Early stop at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
# meta-test
maml.load_best()
def finetune(index):
seq_times = test_env.get_seq_times(index, args.seq_length)
_, (allxs, allys) = test_env.seq_call(seq_times)
allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
if test_env.meta_info["task"] == "classification":
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
future_container = maml.adapt(historical_x, historical_y)
historical_y_hat = maml.predict(historical_x, future_container)
train_metric = metric_cls(True)
# model.analyze_weights()
with torch.no_grad():
train_metric(historical_y_hat, historical_y)
train_results = train_metric.get_info()
return train_results, future_container
train_results, future_container = finetune(0)
metric = metric_cls(True)
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx, (future_time, (future_x, future_y)) in enumerate(test_env):
need_time = "Time Left: {:}".format(
convert_secs2time(per_timestamp_time.avg * (len(test_env) - idx), True)
)
logger.log(
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(test_env))
+ " "
+ need_time
)
# build optimizer
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = maml.predict(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
metric(future_y_hat, future_y)
log_str = (
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(test_env))
+ " train-score: {:.5f}, eval-score: {:.5f}".format(
train_results["score"], metric.get_info()["score"]
)
)
logger.log(log_str)
logger.log("")
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
logger.log("-" * 200 + "\n")
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Use the maml.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/GeMOSA-synthetic/use-maml-nft",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
default=16,
help="The hidden dimension.",
)
parser.add_argument(
"--meta_lr",
type=float,
default=0.02,
help="The learning rate for the MAML optimizer (default is Adam)",
)
parser.add_argument(
"--inner_lr",
type=float,
default=0.005,
help="The learning rate for the inner optimization",
)
parser.add_argument(
"--inner_step", type=int, default=1, help="The inner loop steps for MAML."
)
parser.add_argument(
"--seq_length", type=int, default=20, help="The sequence length."
)
parser.add_argument(
"--meta_batch",
type=int,
default=256,
help="The batch size for the meta-model",
)
parser.add_argument(
"--epochs",
type=int,
default=2000,
help="The total number of epochs.",
)
parser.add_argument(
"--early_stop_thresh",
type=int,
default=50,
help="The maximum epochs for early stop.",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="",
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="The number of data loading workers (default: 4)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-s{:}-mlr{:}-d{:}-e{:}-env{:}".format(
args.save_dir,
args.inner_step,
args.meta_lr,
args.hidden_dim,
args.epochs,
args.env_version,
)
main(args)

View File

@@ -0,0 +1,228 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/slbm-ft.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/baselines/slbm-ft.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/baselines/slbm-ft.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
# python exps/GeMOSA/baselines/slbm-ft.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.procedures.metric_utils import (
SaveMetric,
MSEMetric,
Top1AccMetric,
ComposeMetric,
)
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.utils import show_mean_var
def subsample(historical_x, historical_y, maxn=10000):
total = historical_x.size(0)
if total <= maxn:
return historical_x, historical_y
else:
indexes = torch.randint(low=0, high=total, size=[maxn])
return historical_x[indexes], historical_y[indexes]
def main(args):
prepare_seed(args.rand_seed)
logger = prepare_logger(args)
env = get_synthetic_env(mode="test", version=args.env_version)
model_kwargs = dict(
config=dict(model_type="norm_mlp"),
input_dim=env.meta_info["input_dim"],
output_dim=env.meta_info["output_dim"],
hidden_dims=[args.hidden_dim] * 2,
act_cls="relu",
norm_cls="layer_norm_1d",
)
logger.log("The total enviornment: {:}".format(env))
w_containers = dict()
if env.meta_info["task"] == "regression":
criterion = torch.nn.MSELoss()
metric_cls = MSEMetric
elif env.meta_info["task"] == "classification":
criterion = torch.nn.CrossEntropyLoss()
metric_cls = Top1AccMetric
else:
raise ValueError(
"This task ({:}) is not supported.".format(all_env.meta_info["task"])
)
def finetune(index):
seq_times = env.get_seq_times(index, args.seq_length)
_, (allxs, allys) = env.seq_call(seq_times)
allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
if env.meta_info["task"] == "classification":
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
model = get_model(**model_kwargs)
model = model.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(args.epochs * 0.25),
int(args.epochs * 0.5),
int(args.epochs * 0.75),
],
gamma=0.3,
)
train_metric = metric_cls(True)
best_loss, best_param = None, None
for _iepoch in range(args.epochs):
preds = model(historical_x)
optimizer.zero_grad()
loss = criterion(preds, historical_y)
loss.backward()
optimizer.step()
lr_scheduler.step()
# save best
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
best_param = copy.deepcopy(model.state_dict())
model.load_state_dict(best_param)
# model.analyze_weights()
with torch.no_grad():
train_metric(preds, historical_y)
train_results = train_metric.get_info()
return train_results, model
metric = metric_cls(True)
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx, (future_time, (future_x, future_y)) in enumerate(env):
need_time = "Time Left: {:}".format(
convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
)
logger.log(
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(env))
+ " "
+ need_time
)
# train the same data
train_results, model = finetune(idx)
# build optimizer
xmetric = ComposeMetric(metric_cls(True), SaveMetric())
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = model(future_x)
future_loss = criterion(future_y_hat, future_y)
metric(future_y_hat, future_y)
log_str = (
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(env))
+ " train-score: {:.5f}, eval-score: {:.5f}".format(
train_results["score"], metric.get_info()["score"]
)
)
logger.log(log_str)
logger.log("")
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
save_checkpoint(
{"w_containers": w_containers},
logger.path(None) / "final-ckp.pth",
logger,
)
logger.log("-" * 200 + "\n")
logger.close()
return metric.get_info()["score"]
if __name__ == "__main__":
parser = argparse.ArgumentParser("Use the data in the past.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/GeMOSA-synthetic/use-same-ft-timestamp",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
required=True,
help="The hidden dimension.",
)
parser.add_argument(
"--init_lr",
type=float,
default=0.1,
help="The initial learning rate for the optimizer (default is Adam)",
)
parser.add_argument(
"--seq_length", type=int, default=20, help="The sequence length."
)
parser.add_argument(
"--batch_size",
type=int,
default=512,
help="The batch size",
)
parser.add_argument(
"--epochs",
type=int,
default=300,
help="The total number of epochs.",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="",
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="The number of data loading workers (default: 4)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format(
args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version
)
if args.rand_seed is None or args.rand_seed < 0:
results = []
for iseed in range(3):
args.rand_seed = random.randint(1, 100000)
result = main(args)
results.append(result)
show_mean_var(results)
else:
main(args)

View File

@@ -0,0 +1,227 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/baselines/slbm-nof.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/baselines/slbm-nof.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/baselines/slbm-nof.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
# python exps/GeMOSA/baselines/slbm-nof.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.procedures.metric_utils import (
SaveMetric,
MSEMetric,
Top1AccMetric,
ComposeMetric,
)
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.utils import show_mean_var
def subsample(historical_x, historical_y, maxn=10000):
total = historical_x.size(0)
if total <= maxn:
return historical_x, historical_y
else:
indexes = torch.randint(low=0, high=total, size=[maxn])
return historical_x[indexes], historical_y[indexes]
def main(args):
prepare_seed(args.rand_seed)
logger = prepare_logger(args)
env = get_synthetic_env(mode="test", version=args.env_version)
model_kwargs = dict(
config=dict(model_type="norm_mlp"),
input_dim=env.meta_info["input_dim"],
output_dim=env.meta_info["output_dim"],
hidden_dims=[args.hidden_dim] * 2,
act_cls="relu",
norm_cls="layer_norm_1d",
)
logger.log("The total enviornment: {:}".format(env))
w_containers = dict()
if env.meta_info["task"] == "regression":
criterion = torch.nn.MSELoss()
metric_cls = MSEMetric
elif env.meta_info["task"] == "classification":
criterion = torch.nn.CrossEntropyLoss()
metric_cls = Top1AccMetric
else:
raise ValueError(
"This task ({:}) is not supported.".format(all_env.meta_info["task"])
)
seq_times = env.get_seq_times(0, args.seq_length)
_, (allxs, allys) = env.seq_call(seq_times)
allxs, allys = allxs.view(-1, allxs.shape[-1]), allys.view(-1, 1)
if env.meta_info["task"] == "classification":
allys = allys.view(-1)
historical_x, historical_y = allxs.to(args.device), allys.to(args.device)
model = get_model(**model_kwargs)
model = model.to(args.device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(args.epochs * 0.25),
int(args.epochs * 0.5),
int(args.epochs * 0.75),
],
gamma=0.3,
)
train_metric = metric_cls(True)
best_loss, best_param = None, None
for _iepoch in range(args.epochs):
preds = model(historical_x)
optimizer.zero_grad()
loss = criterion(preds, historical_y)
loss.backward()
optimizer.step()
lr_scheduler.step()
# save best
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
best_param = copy.deepcopy(model.state_dict())
model.load_state_dict(best_param)
model.analyze_weights()
with torch.no_grad():
train_metric(preds, historical_y)
train_results = train_metric.get_info()
print(train_results)
metric = metric_cls(True)
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx, (future_time, (future_x, future_y)) in enumerate(env):
need_time = "Time Left: {:}".format(
convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
)
logger.log(
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(env))
+ " "
+ need_time
)
# train the same data
# build optimizer
xmetric = ComposeMetric(metric_cls(True), SaveMetric())
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = model(future_x)
future_loss = criterion(future_y_hat, future_y)
metric(future_y_hat, future_y)
log_str = (
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(env))
+ " train-score: {:.5f}, eval-score: {:.5f}".format(
train_results["score"], metric.get_info()["score"]
)
)
logger.log(log_str)
logger.log("")
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
save_checkpoint(
{"w_containers": w_containers},
logger.path(None) / "final-ckp.pth",
logger,
)
logger.log("-" * 200 + "\n")
logger.close()
return metric.get_info()["score"]
if __name__ == "__main__":
parser = argparse.ArgumentParser("Use the data in the past.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/GeMOSA-synthetic/use-same-nof-timestamp",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
required=True,
help="The hidden dimension.",
)
parser.add_argument(
"--seq_length", type=int, default=20, help="The sequence length."
)
parser.add_argument(
"--init_lr",
type=float,
default=0.1,
help="The initial learning rate for the optimizer (default is Adam)",
)
parser.add_argument(
"--batch_size",
type=int,
default=512,
help="The batch size",
)
parser.add_argument(
"--epochs",
type=int,
default=300,
help="The total number of epochs.",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="",
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="The number of data loading workers (default: 4)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format(
args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version
)
if args.rand_seed is None or args.rand_seed < 0:
results = []
for iseed in range(3):
args.rand_seed = random.randint(1, 100000)
result = main(args)
results.append(result)
show_mean_var(results)
else:
main(args)

View File

@@ -0,0 +1,206 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-his.py --srange 1-999 --env_version v1 --hidden_dim 16
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from lfna_utils import lfna_setup
def subsample(historical_x, historical_y, maxn=10000):
total = historical_x.size(0)
if total <= maxn:
return historical_x, historical_y
else:
indexes = torch.randint(low=0, high=total, size=[maxn])
return historical_x[indexes], historical_y[indexes]
def main(args):
logger, env_info, model_kwargs = lfna_setup(args)
# check indexes to be evaluated
to_evaluate_indexes = split_str2indexes(args.srange, env_info["total"], None)
logger.log(
"Evaluate {:}, which has {:} timestamps in total.".format(
args.srange, len(to_evaluate_indexes)
)
)
w_container_per_epoch = dict()
per_timestamp_time, start_time = AverageMeter(), time.time()
for i, idx in enumerate(to_evaluate_indexes):
need_time = "Time Left: {:}".format(
convert_secs2time(
per_timestamp_time.avg * (len(to_evaluate_indexes) - i), True
)
)
logger.log(
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}][{:04d}]".format(i, len(to_evaluate_indexes), idx)
+ " "
+ need_time
)
# train the same data
assert idx != 0
historical_x, historical_y = [], []
for past_i in range(idx):
historical_x.append(env_info["{:}-x".format(past_i)])
historical_y.append(env_info["{:}-y".format(past_i)])
historical_x, historical_y = torch.cat(historical_x), torch.cat(historical_y)
historical_x, historical_y = subsample(historical_x, historical_y)
# build model
model = get_model(dict(model_type="simple_mlp"), **model_kwargs)
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss()
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(args.epochs * 0.25),
int(args.epochs * 0.5),
int(args.epochs * 0.75),
],
gamma=0.3,
)
train_metric = MSEMetric()
best_loss, best_param = None, None
for _iepoch in range(args.epochs):
preds = model(historical_x)
optimizer.zero_grad()
loss = criterion(preds, historical_y)
loss.backward()
optimizer.step()
lr_scheduler.step()
# save best
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
best_param = copy.deepcopy(model.state_dict())
model.load_state_dict(best_param)
with torch.no_grad():
train_metric(preds, historical_y)
train_results = train_metric.get_info()
metric = ComposeMetric(MSEMetric(), SaveMetric())
eval_dataset = torch.utils.data.TensorDataset(
env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
)
eval_loader = torch.utils.data.DataLoader(
eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
)
results = basic_eval_fn(eval_loader, model, metric, logger)
log_str = (
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, env_info["total"])
+ " train-mse: {:.5f}, eval-mse: {:.5f}".format(
train_results["mse"], results["mse"]
)
)
logger.log(log_str)
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
idx, env_info["total"]
)
w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
save_checkpoint(
{
"model_state_dict": model.state_dict(),
"model": model,
"index": idx,
"timestamp": env_info["{:}-timestamp".format(idx)],
},
save_path,
logger,
)
logger.log("")
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
save_checkpoint(
{"w_container_per_epoch": w_container_per_epoch},
logger.path(None) / "final-ckp.pth",
logger,
)
logger.log("-" * 200 + "\n")
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Use all the past data to train.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/lfna-synthetic/use-all-past-data",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
required=True,
help="The hidden dimension.",
)
parser.add_argument(
"--init_lr",
type=float,
default=0.1,
help="The initial learning rate for the optimizer (default is Adam)",
)
parser.add_argument(
"--batch_size",
type=int,
default=512,
help="The batch size",
)
parser.add_argument(
"--epochs",
type=int,
default=1000,
help="The total number of epochs.",
)
parser.add_argument(
"--srange", type=str, required=True, help="The range of models to be evaluated"
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="The number of data loading workers (default: 4)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-{:}-d{:}".format(
args.save_dir, args.env_version, args.hidden_dim
)
main(args)

View File

@@ -0,0 +1,207 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/GeMOSA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from lfna_utils import lfna_setup
def subsample(historical_x, historical_y, maxn=10000):
total = historical_x.size(0)
if total <= maxn:
return historical_x, historical_y
else:
indexes = torch.randint(low=0, high=total, size=[maxn])
return historical_x[indexes], historical_y[indexes]
def main(args):
logger, model_kwargs = lfna_setup(args)
w_containers = dict()
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx in range(args.prev_time, env_info["total"]):
need_time = "Time Left: {:}".format(
convert_secs2time(per_timestamp_time.avg * (env_info["total"] - idx), True)
)
logger.log(
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, env_info["total"])
+ " "
+ need_time
)
# train the same data
historical_x = env_info["{:}-x".format(idx - args.prev_time)]
historical_y = env_info["{:}-y".format(idx - args.prev_time)]
# build model
model = get_model(**model_kwargs)
print(model)
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss()
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(args.epochs * 0.25),
int(args.epochs * 0.5),
int(args.epochs * 0.75),
],
gamma=0.3,
)
train_metric = MSEMetric()
best_loss, best_param = None, None
for _iepoch in range(args.epochs):
preds = model(historical_x)
optimizer.zero_grad()
loss = criterion(preds, historical_y)
loss.backward()
optimizer.step()
lr_scheduler.step()
# save best
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
best_param = copy.deepcopy(model.state_dict())
model.load_state_dict(best_param)
model.analyze_weights()
with torch.no_grad():
train_metric(preds, historical_y)
train_results = train_metric.get_info()
metric = ComposeMetric(MSEMetric(), SaveMetric())
eval_dataset = torch.utils.data.TensorDataset(
env_info["{:}-x".format(idx)], env_info["{:}-y".format(idx)]
)
eval_loader = torch.utils.data.DataLoader(
eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
)
results = basic_eval_fn(eval_loader, model, metric, logger)
log_str = (
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, env_info["total"])
+ " train-mse: {:.5f}, eval-mse: {:.5f}".format(
train_results["mse"], results["mse"]
)
)
logger.log(log_str)
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
idx, env_info["total"]
)
w_containers[idx] = model.get_w_container().no_grad_clone()
save_checkpoint(
{
"model_state_dict": model.state_dict(),
"model": model,
"index": idx,
"timestamp": env_info["{:}-timestamp".format(idx)],
},
save_path,
logger,
)
logger.log("")
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
save_checkpoint(
{"w_containers": w_containers},
logger.path(None) / "final-ckp.pth",
logger,
)
logger.log("-" * 200 + "\n")
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Use the data in the last timestamp.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/lfna-synthetic/use-prev-timestamp",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
required=True,
help="The hidden dimension.",
)
parser.add_argument(
"--init_lr",
type=float,
default=0.1,
help="The initial learning rate for the optimizer (default is Adam)",
)
parser.add_argument(
"--prev_time",
type=int,
default=5,
help="The gap between prev_time and current_timestamp",
)
parser.add_argument(
"--batch_size",
type=int,
default=512,
help="The batch size",
)
parser.add_argument(
"--epochs",
type=int,
default=300,
help="The total number of epochs.",
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="The number of data loading workers (default: 4)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-d{:}_e{:}_lr{:}-prev{:}-env{:}".format(
args.save_dir,
args.hidden_dim,
args.epochs,
args.init_lr,
args.prev_time,
args.env_version,
)
main(args)

View File

@@ -0,0 +1,228 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/GeMOSA/basic-same.py --env_version v1 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/basic-same.py --env_version v2 --hidden_dim 16 --epochs 500 --init_lr 0.1 --device cuda
# python exps/GeMOSA/basic-same.py --env_version v3 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
# python exps/GeMOSA/basic-same.py --env_version v4 --hidden_dim 32 --epochs 1000 --init_lr 0.05 --device cuda
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.metric_utils import (
SaveMetric,
MSEMetric,
Top1AccMetric,
ComposeMetric,
)
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
def subsample(historical_x, historical_y, maxn=10000):
total = historical_x.size(0)
if total <= maxn:
return historical_x, historical_y
else:
indexes = torch.randint(low=0, high=total, size=[maxn])
return historical_x[indexes], historical_y[indexes]
def main(args):
prepare_seed(args.rand_seed)
logger = prepare_logger(args)
env = get_synthetic_env(mode=None, version=args.env_version)
model_kwargs = dict(
config=dict(model_type="norm_mlp"),
input_dim=env.meta_info["input_dim"],
output_dim=env.meta_info["output_dim"],
hidden_dims=[args.hidden_dim] * 2,
act_cls="relu",
norm_cls="layer_norm_1d",
)
logger.log("The total enviornment: {:}".format(env))
w_containers = dict()
if env.meta_info["task"] == "regression":
criterion = torch.nn.MSELoss()
metric_cls = MSEMetric
elif env.meta_info["task"] == "classification":
criterion = torch.nn.CrossEntropyLoss()
metric_cls = Top1AccMetric
else:
raise ValueError(
"This task ({:}) is not supported.".format(all_env.meta_info["task"])
)
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx, (future_time, (future_x, future_y)) in enumerate(env):
need_time = "Time Left: {:}".format(
convert_secs2time(per_timestamp_time.avg * (len(env) - idx), True)
)
logger.log(
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(env))
+ " "
+ need_time
)
# train the same data
historical_x = future_x.to(args.device)
historical_y = future_y.to(args.device)
# build model
model = get_model(**model_kwargs)
model = model.to(args.device)
if idx == 0:
print(model)
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(args.epochs * 0.25),
int(args.epochs * 0.5),
int(args.epochs * 0.75),
],
gamma=0.3,
)
train_metric = metric_cls(True)
best_loss, best_param = None, None
for _iepoch in range(args.epochs):
preds = model(historical_x)
optimizer.zero_grad()
loss = criterion(preds, historical_y)
loss.backward()
optimizer.step()
lr_scheduler.step()
# save best
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
best_param = copy.deepcopy(model.state_dict())
model.load_state_dict(best_param)
model.analyze_weights()
with torch.no_grad():
train_metric(preds, historical_y)
train_results = train_metric.get_info()
xmetric = ComposeMetric(metric_cls(True), SaveMetric())
eval_dataset = torch.utils.data.TensorDataset(
future_x.to(args.device), future_y.to(args.device)
)
eval_loader = torch.utils.data.DataLoader(
eval_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0
)
results = basic_eval_fn(eval_loader, model, xmetric, logger)
log_str = (
"[{:}]".format(time_string())
+ " [{:04d}/{:04d}]".format(idx, len(env))
+ " train-score: {:.5f}, eval-score: {:.5f}".format(
train_results["score"], results["score"]
)
)
logger.log(log_str)
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(idx, len(env))
w_containers[idx] = model.get_w_container().no_grad_clone()
save_checkpoint(
{
"model_state_dict": model.state_dict(),
"model": model,
"index": idx,
"timestamp": future_time.item(),
},
save_path,
logger,
)
logger.log("")
per_timestamp_time.update(time.time() - start_time)
start_time = time.time()
save_checkpoint(
{"w_containers": w_containers},
logger.path(None) / "final-ckp.pth",
logger,
)
logger.log("-" * 200 + "\n")
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser("Use the data in the past.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/GeMOSA-synthetic/use-same-timestamp",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
required=True,
help="The hidden dimension.",
)
parser.add_argument(
"--init_lr",
type=float,
default=0.1,
help="The initial learning rate for the optimizer (default is Adam)",
)
parser.add_argument(
"--batch_size",
type=int,
default=512,
help="The batch size",
)
parser.add_argument(
"--epochs",
type=int,
default=300,
help="The total number of epochs.",
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="",
)
parser.add_argument(
"--workers",
type=int,
default=4,
help="The number of data loading workers (default: 4)",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
args.save_dir = "{:}-d{:}_e{:}_lr{:}-env{:}".format(
args.save_dir, args.hidden_dim, args.epochs, args.init_lr, args.env_version
)
main(args)

View File

@@ -0,0 +1,438 @@
##########################################################
# Learning to Efficiently Generate Models One Step Ahead #
##########################################################
# <----> run on CPU
# python exps/GeMOSA/main.py --env_version v1 --workers 0
# <----> run on a GPU
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda
# <----> ablation commands
# python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda
# python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --ablation old --device cuda
# python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda
# python exps/GeMOSA/main.py --env_version v4 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --ablation old --device cuda
##########################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
from torch.nn import functional as F
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
save_checkpoint,
copy_checkpoint,
)
from xautodl.log_utils import time_string
from xautodl.log_utils import AverageMeter, convert_secs2time
from xautodl.utils import split_str2indexes
from xautodl.procedures.advanced_main import basic_train_fn, basic_eval_fn
from xautodl.procedures.metric_utils import SaveMetric, MSEMetric, ComposeMetric
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.models.xcore import get_model
from xautodl.procedures.metric_utils import MSEMetric, Top1AccMetric
from meta_model import MetaModelV1
from meta_model_ablation import MetaModel_TraditionalAtt
def online_evaluate(
env,
meta_model,
base_model,
criterion,
metric,
args,
logger,
save=False,
easy_adapt=False,
):
logger.log("Online evaluate: {:}".format(env))
metric.reset()
loss_meter = AverageMeter()
w_containers = dict()
for idx, (future_time, (future_x, future_y)) in enumerate(env):
with torch.no_grad():
meta_model.eval()
base_model.eval()
future_time_embed = meta_model.gen_time_embed(
future_time.to(args.device).view(-1)
)
[future_container] = meta_model.gen_model(future_time_embed)
if save:
w_containers[idx] = future_container.no_grad_clone()
future_x, future_y = future_x.to(args.device), future_y.to(args.device)
future_y_hat = base_model.forward_with_container(future_x, future_container)
future_loss = criterion(future_y_hat, future_y)
loss_meter.update(future_loss.item())
# accumulate the metric scores
score = metric(future_y_hat, future_y)
if easy_adapt:
meta_model.easy_adapt(future_time.item(), future_time_embed)
refine, post_refine_loss = False, -1
else:
refine, post_refine_loss = meta_model.adapt(
base_model,
criterion,
future_time.item(),
future_x,
future_y,
args.refine_lr,
args.refine_epochs,
{"param": future_time_embed, "loss": future_loss.item()},
)
logger.log(
"[ONLINE] [{:03d}/{:03d}] loss={:.4f}, score={:.4f}".format(
idx, len(env), future_loss.item(), score
)
+ ", post-loss={:.4f}".format(post_refine_loss if refine else -1)
)
meta_model.clear_fixed()
meta_model.clear_learnt()
return w_containers, loss_meter.avg, metric.get_info()["score"]
def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
base_model.train()
meta_model.train()
optimizer = torch.optim.Adam(
meta_model.get_parameters(True, True, True),
lr=args.lr,
weight_decay=args.weight_decay,
amsgrad=True,
)
logger.log("Pre-train the meta-model")
logger.log("Using the optimizer: {:}".format(optimizer))
meta_model.set_best_dir(logger.path(None) / "ckps-pretrain-v2")
final_best_name = "final-pretrain-{:}.pth".format(args.rand_seed)
if meta_model.has_best(final_best_name):
meta_model.load_best(final_best_name)
logger.log("Directly load the best model from {:}".format(final_best_name))
return
total_indexes = list(range(meta_model.meta_length))
meta_model.set_best_name("pretrain-{:}.pth".format(args.rand_seed))
last_success_epoch, early_stop_thresh = 0, args.pretrain_early_stop_thresh
per_epoch_time, start_time = AverageMeter(), time.time()
device = args.device
for iepoch in range(args.epochs):
left_time = "Time Left: {:}".format(
convert_secs2time(per_epoch_time.avg * (args.epochs - iepoch), True)
)
optimizer.zero_grad()
generated_time_embeds = meta_model.gen_time_embed(meta_model.meta_timestamps)
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
raw_time_steps = meta_model.meta_timestamps[batch_indexes]
regularization_loss = F.l1_loss(
generated_time_embeds, meta_model.super_meta_embed, reduction="mean"
)
# future loss
total_future_losses, total_present_losses = [], []
future_containers = meta_model.gen_model(generated_time_embeds[batch_indexes])
present_containers = meta_model.gen_model(
meta_model.super_meta_embed[batch_indexes]
)
for ibatch, time_step in enumerate(raw_time_steps.cpu().tolist()):
_, (inputs, targets) = xenv(time_step)
inputs, targets = inputs.to(device), targets.to(device)
predictions = base_model.forward_with_container(
inputs, future_containers[ibatch]
)
total_future_losses.append(criterion(predictions, targets))
predictions = base_model.forward_with_container(
inputs, present_containers[ibatch]
)
total_present_losses.append(criterion(predictions, targets))
with torch.no_grad():
meta_std = torch.stack(total_future_losses).std().item()
loss_future = torch.stack(total_future_losses).mean()
loss_present = torch.stack(total_present_losses).mean()
total_loss = loss_future + loss_present + regularization_loss
total_loss.backward()
optimizer.step()
# success
success, best_score = meta_model.save_best(-total_loss.item())
logger.log(
"{:} [META {:04d}/{:}] loss : {:.4f} +- {:.4f} = {:.4f} + {:.4f} + {:.4f}".format(
time_string(),
iepoch,
args.epochs,
total_loss.item(),
meta_std,
loss_future.item(),
loss_present.item(),
regularization_loss.item(),
)
+ ", batch={:}".format(len(total_future_losses))
+ ", success={:}, best={:.4f}".format(success, -best_score)
+ ", LS={:}/{:}".format(iepoch - last_success_epoch, early_stop_thresh)
+ ", {:}".format(left_time)
)
if success:
last_success_epoch = iepoch
if iepoch - last_success_epoch >= early_stop_thresh:
logger.log("Early stop the pre-training at {:}".format(iepoch))
break
per_epoch_time.update(time.time() - start_time)
start_time = time.time()
meta_model.load_best()
# save to the final model
meta_model.set_best_name(final_best_name)
success, _ = meta_model.save_best(best_score + 1e-6)
assert success
logger.log("Save the best model into {:}".format(final_best_name))
def main(args):
prepare_seed(args.rand_seed)
logger = prepare_logger(args)
train_env = get_synthetic_env(mode="train", version=args.env_version)
valid_env = get_synthetic_env(mode="valid", version=args.env_version)
trainval_env = get_synthetic_env(mode="trainval", version=args.env_version)
test_env = get_synthetic_env(mode="test", version=args.env_version)
all_env = get_synthetic_env(mode=None, version=args.env_version)
logger.log("The training enviornment: {:}".format(train_env))
logger.log("The validation enviornment: {:}".format(valid_env))
logger.log("The trainval enviornment: {:}".format(trainval_env))
logger.log("The total enviornment: {:}".format(all_env))
logger.log("The test enviornment: {:}".format(test_env))
model_kwargs = dict(
config=dict(model_type="norm_mlp"),
input_dim=all_env.meta_info["input_dim"],
output_dim=all_env.meta_info["output_dim"],
hidden_dims=[args.hidden_dim] * 2,
act_cls="relu",
norm_cls="layer_norm_1d",
)
base_model = get_model(**model_kwargs)
base_model = base_model.to(args.device)
if all_env.meta_info["task"] == "regression":
criterion = torch.nn.MSELoss()
metric = MSEMetric(True)
elif all_env.meta_info["task"] == "classification":
criterion = torch.nn.CrossEntropyLoss()
metric = Top1AccMetric(True)
else:
raise ValueError(
"This task ({:}) is not supported.".format(all_env.meta_info["task"])
)
shape_container = base_model.get_w_container().to_shape_container()
# pre-train the hypernetwork
timestamps = trainval_env.get_timestamp(None)
if args.ablation is None:
MetaModel_cls = MetaModelV1
elif args.ablation == "old":
MetaModel_cls = MetaModel_TraditionalAtt
else:
raise ValueError("Unknown ablation : {:}".format(args.ablation))
meta_model = MetaModel_cls(
shape_container,
args.layer_dim,
args.time_dim,
timestamps,
seq_length=args.seq_length,
interval=trainval_env.time_interval,
)
meta_model = meta_model.to(args.device)
logger.log("The base-model has {:} weights.".format(base_model.numel()))
logger.log("The meta-model has {:} weights.".format(meta_model.numel()))
logger.log("The base-model is\n{:}".format(base_model))
logger.log("The meta-model is\n{:}".format(meta_model))
meta_train_procedure(base_model, meta_model, criterion, trainval_env, args, logger)
# try to evaluate once
# online_evaluate(train_env, meta_model, base_model, criterion, args, logger)
# online_evaluate(valid_env, meta_model, base_model, criterion, args, logger)
"""
w_containers, loss_meter = online_evaluate(
all_env, meta_model, base_model, criterion, args, logger, True
)
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
"""
w_containers_care_adapt, loss_adapt_v1, metric_adapt_v1 = online_evaluate(
test_env, meta_model, base_model, criterion, metric, args, logger, True, False
)
w_containers_easy_adapt, loss_adapt_v2, metric_adapt_v2 = online_evaluate(
test_env, meta_model, base_model, criterion, metric, args, logger, True, True
)
logger.log(
"[Refine-Adapt] loss = {:.6f}, metric = {:.6f}".format(
loss_adapt_v1, metric_adapt_v1
)
)
logger.log(
"[Easy-Adapt] loss = {:.6f}, metric = {:.6f}".format(
loss_adapt_v2, metric_adapt_v2
)
)
save_checkpoint(
{
"w_containers_care_adapt": w_containers_care_adapt,
"w_containers_easy_adapt": w_containers_easy_adapt,
"test_loss_adapt_v1": loss_adapt_v1,
"test_loss_adapt_v2": loss_adapt_v2,
"test_metric_adapt_v1": metric_adapt_v1,
"test_metric_adapt_v2": metric_adapt_v2,
},
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
logger,
)
logger.log("-" * 200 + "\n")
logger.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser(".")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/GeMOSA-synthetic/GeMOSA",
help="The checkpoint directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
parser.add_argument(
"--hidden_dim",
type=int,
default=16,
help="The hidden dimension.",
)
parser.add_argument(
"--layer_dim",
type=int,
default=16,
help="The layer chunk dimension.",
)
parser.add_argument(
"--time_dim",
type=int,
default=16,
help="The timestamp dimension.",
)
#####
parser.add_argument(
"--lr",
type=float,
default=0.002,
help="The initial learning rate for the optimizer (default is Adam)",
)
parser.add_argument(
"--weight_decay",
type=float,
default=0.00001,
help="The weight decay for the optimizer (default is Adam)",
)
parser.add_argument(
"--meta_batch",
type=int,
default=64,
help="The batch size for the meta-model",
)
parser.add_argument(
"--sampler_enlarge",
type=int,
default=5,
help="Enlarge the #iterations for an epoch",
)
parser.add_argument("--epochs", type=int, default=10000, help="The total #epochs.")
parser.add_argument(
"--refine_lr",
type=float,
default=0.001,
help="The learning rate for the optimizer, during refine",
)
parser.add_argument(
"--refine_epochs", type=int, default=150, help="The final refine #epochs."
)
parser.add_argument(
"--early_stop_thresh",
type=int,
default=20,
help="The #epochs for early stop.",
)
parser.add_argument(
"--pretrain_early_stop_thresh",
type=int,
default=300,
help="The #epochs for early stop.",
)
parser.add_argument(
"--seq_length", type=int, default=10, help="The sequence length."
)
parser.add_argument(
"--workers", type=int, default=4, help="The number of workers in parallel."
)
parser.add_argument(
"--ablation", type=str, default=None, help="The ablation indicator."
)
parser.add_argument(
"--device",
type=str,
default="cpu",
help="",
)
# Random Seed
parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
args = parser.parse_args()
if args.rand_seed is None or args.rand_seed < 0:
args.rand_seed = random.randint(1, 100000)
assert args.save_dir is not None, "The save dir argument can not be None"
if args.ablation is None:
args.save_dir = "{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-env{:}".format(
args.save_dir,
args.meta_batch,
args.hidden_dim,
args.layer_dim,
args.time_dim,
args.seq_length,
args.lr,
args.weight_decay,
args.epochs,
args.env_version,
)
else:
args.save_dir = (
"{:}-bs{:}-d{:}_{:}_{:}-s{:}-lr{:}-wd{:}-e{:}-ab{:}-env{:}".format(
args.save_dir,
args.meta_batch,
args.hidden_dim,
args.layer_dim,
args.time_dim,
args.seq_length,
args.lr,
args.weight_decay,
args.epochs,
args.ablation,
args.env_version,
)
)
main(args)

View File

@@ -0,0 +1,257 @@
import torch
import torch.nn.functional as F
from xautodl.xlayers import super_core
from xautodl.xlayers import trunc_normal_
from xautodl.xmodels.xcore import get_model
class MetaModelV1(super_core.SuperModule):
"""Learning to Generate Models One Step Ahead (Meta Model Design)."""
def __init__(
self,
shape_container,
layer_dim,
time_dim,
meta_timestamps,
dropout: float = 0.1,
seq_length: int = None,
interval: float = None,
thresh: float = None,
):
super(MetaModelV1, self).__init__()
self._shape_container = shape_container
self._num_layers = len(shape_container)
self._numel_per_layer = []
for ilayer in range(self._num_layers):
self._numel_per_layer.append(shape_container[ilayer].numel())
self._raw_meta_timestamps = meta_timestamps
assert interval is not None
self._interval = interval
self._thresh = interval * seq_length if thresh is None else thresh
self.register_parameter(
"_super_layer_embed",
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
)
self.register_parameter(
"_super_meta_embed",
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
)
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
self._time_embed_dim = time_dim
self._append_meta_embed = dict(fixed=None, learnt=None)
self._append_meta_timestamps = dict(fixed=None, learnt=None)
self._tscalar_embed = super_core.SuperDynamicPositionE(
time_dim, scale=1 / interval
)
# build transformer
self._trans_att = super_core.SuperQKVAttentionV2(
qk_att_dim=time_dim,
in_v_dim=time_dim,
hidden_dim=time_dim,
num_heads=4,
proj_dim=time_dim,
qkv_bias=True,
attn_drop=None,
proj_drop=dropout,
)
model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"),
input_dim=layer_dim + time_dim,
output_dim=max(self._numel_per_layer),
hidden_dims=[(layer_dim + time_dim) * 2] * 3,
act_cls="gelu",
norm_cls="layer_norm_1d",
dropout=dropout,
)
self._generator = get_model(**model_kwargs)
# initialization
trunc_normal_(
[self._super_layer_embed, self._super_meta_embed],
std=0.02,
)
def get_parameters(self, time_embed, attention, generator):
parameters = []
if time_embed:
parameters.append(self._super_meta_embed)
if attention:
parameters.extend(list(self._trans_att.parameters()))
if generator:
parameters.append(self._super_layer_embed)
parameters.extend(list(self._generator.parameters()))
return parameters
@property
def meta_timestamps(self):
with torch.no_grad():
meta_timestamps = [self._meta_timestamps]
for key in ("fixed", "learnt"):
if self._append_meta_timestamps[key] is not None:
meta_timestamps.append(self._append_meta_timestamps[key])
return torch.cat(meta_timestamps)
@property
def super_meta_embed(self):
meta_embed = [self._super_meta_embed]
for key in ("fixed", "learnt"):
if self._append_meta_embed[key] is not None:
meta_embed.append(self._append_meta_embed[key])
return torch.cat(meta_embed)
def create_meta_embed(self):
param = torch.Tensor(1, self._time_embed_dim)
trunc_normal_(param, std=0.02)
param = param.to(self._super_meta_embed.device)
param = torch.nn.Parameter(param, True)
return param
def get_closest_meta_distance(self, timestamp):
with torch.no_grad():
distances = torch.abs(self.meta_timestamps - timestamp)
return torch.min(distances).item()
def replace_append_learnt(self, timestamp, meta_embed):
self._append_meta_timestamps["learnt"] = timestamp
self._append_meta_embed["learnt"] = meta_embed
@property
def meta_length(self):
return self.meta_timestamps.numel()
def clear_fixed(self):
self._append_meta_timestamps["fixed"] = None
self._append_meta_embed["fixed"] = None
def clear_learnt(self):
self.replace_append_learnt(None, None)
def append_fixed(self, timestamp, meta_embed):
with torch.no_grad():
device = self._super_meta_embed.device
timestamp = timestamp.detach().clone().to(device)
meta_embed = meta_embed.detach().clone().to(device)
if self._append_meta_timestamps["fixed"] is None:
self._append_meta_timestamps["fixed"] = timestamp
else:
self._append_meta_timestamps["fixed"] = torch.cat(
(self._append_meta_timestamps["fixed"], timestamp), dim=0
)
if self._append_meta_embed["fixed"] is None:
self._append_meta_embed["fixed"] = meta_embed
else:
self._append_meta_embed["fixed"] = torch.cat(
(self._append_meta_embed["fixed"], meta_embed), dim=0
)
def gen_time_embed(self, timestamps):
# timestamps is a batch of timestamps
[B] = timestamps.shape
# batch, seq = timestamps.shape
timestamps = timestamps.view(-1, 1)
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
timestamp_qk_att_embed = self._tscalar_embed(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps
)
# create the mask
mask = (
torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
) | (
torch.abs(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1)
)
> self._thresh
)
timestamp_embeds = self._trans_att(
timestamp_qk_att_embed,
timestamp_v_embed,
mask,
)
return timestamp_embeds[:, -1, :]
def gen_model(self, time_embeds):
B, _ = time_embeds.shape
# create joint embed
num_layer, _ = self._super_layer_embed.shape
# The shape of `joint_embed` is batch * num-layers * input-dim
joint_embeds = torch.cat(
(
time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
),
dim=-1,
)
batch_weights = self._generator(joint_embeds)
batch_containers = []
for weights in torch.split(batch_weights, 1):
batch_containers.append(
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
)
return batch_containers
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
raise NotImplementedError
def forward_candidate(self, input):
raise NotImplementedError
def easy_adapt(self, timestamp, time_embed):
with torch.no_grad():
timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, time_embed)
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
distance = self.get_closest_meta_distance(timestamp)
if distance + self._interval * 1e-2 <= self._interval:
return False, None
x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
with torch.set_grad_enabled(True):
new_param = self.create_meta_embed()
optimizer = torch.optim.Adam(
[new_param], lr=lr, weight_decay=1e-5, amsgrad=True
)
timestamp = torch.Tensor([timestamp]).to(new_param.device)
self.replace_append_learnt(timestamp, new_param)
self.train()
base_model.train()
if init_info is not None:
best_loss = init_info["loss"]
new_param.data.copy_(init_info["param"].data)
else:
best_loss = 1e9
with torch.no_grad():
best_new_param = new_param.detach().clone()
for iepoch in range(epochs):
optimizer.zero_grad()
time_embed = self.gen_time_embed(timestamp.view(1))
match_loss = F.l1_loss(new_param, time_embed)
[container] = self.gen_model(new_param.view(1, -1))
y_hat = base_model.forward_with_container(x, container)
meta_loss = criterion(y_hat, y)
loss = meta_loss + match_loss
loss.backward()
optimizer.step()
if meta_loss.item() < best_loss:
with torch.no_grad():
best_loss = meta_loss.item()
best_new_param = new_param.detach().clone()
self.easy_adapt(timestamp, best_new_param)
return True, best_loss
def extra_repr(self) -> str:
return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
list(self._super_layer_embed.shape),
list(self._super_meta_embed.shape),
list(self._meta_timestamps.shape),
)

View File

@@ -0,0 +1,260 @@
#
# This is used for the ablation studies:
# The meta-model in this file uses the traditional attention in
# transformer.
#
import torch
import torch.nn.functional as F
from xautodl.xlayers import super_core
from xautodl.xlayers import trunc_normal_
from xautodl.models.xcore import get_model
class MetaModel_TraditionalAtt(super_core.SuperModule):
"""Learning to Generate Models One Step Ahead (Meta Model Design)."""
def __init__(
self,
shape_container,
layer_dim,
time_dim,
meta_timestamps,
dropout: float = 0.1,
seq_length: int = None,
interval: float = None,
thresh: float = None,
):
super(MetaModel_TraditionalAtt, self).__init__()
self._shape_container = shape_container
self._num_layers = len(shape_container)
self._numel_per_layer = []
for ilayer in range(self._num_layers):
self._numel_per_layer.append(shape_container[ilayer].numel())
self._raw_meta_timestamps = meta_timestamps
assert interval is not None
self._interval = interval
self._thresh = interval * seq_length if thresh is None else thresh
self.register_parameter(
"_super_layer_embed",
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_dim)),
)
self.register_parameter(
"_super_meta_embed",
torch.nn.Parameter(torch.Tensor(len(meta_timestamps), time_dim)),
)
self.register_buffer("_meta_timestamps", torch.Tensor(meta_timestamps))
self._time_embed_dim = time_dim
self._append_meta_embed = dict(fixed=None, learnt=None)
self._append_meta_timestamps = dict(fixed=None, learnt=None)
self._tscalar_embed = super_core.SuperDynamicPositionE(
time_dim, scale=1 / interval
)
# build transformer
self._trans_att = super_core.SuperQKVAttention(
in_q_dim=time_dim,
in_k_dim=time_dim,
in_v_dim=time_dim,
num_heads=4,
proj_dim=time_dim,
qkv_bias=True,
attn_drop=None,
proj_drop=dropout,
)
model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"),
input_dim=layer_dim + time_dim,
output_dim=max(self._numel_per_layer),
hidden_dims=[(layer_dim + time_dim) * 2] * 3,
act_cls="gelu",
norm_cls="layer_norm_1d",
dropout=dropout,
)
self._generator = get_model(**model_kwargs)
# initialization
trunc_normal_(
[self._super_layer_embed, self._super_meta_embed],
std=0.02,
)
def get_parameters(self, time_embed, attention, generator):
parameters = []
if time_embed:
parameters.append(self._super_meta_embed)
if attention:
parameters.extend(list(self._trans_att.parameters()))
if generator:
parameters.append(self._super_layer_embed)
parameters.extend(list(self._generator.parameters()))
return parameters
@property
def meta_timestamps(self):
with torch.no_grad():
meta_timestamps = [self._meta_timestamps]
for key in ("fixed", "learnt"):
if self._append_meta_timestamps[key] is not None:
meta_timestamps.append(self._append_meta_timestamps[key])
return torch.cat(meta_timestamps)
@property
def super_meta_embed(self):
meta_embed = [self._super_meta_embed]
for key in ("fixed", "learnt"):
if self._append_meta_embed[key] is not None:
meta_embed.append(self._append_meta_embed[key])
return torch.cat(meta_embed)
def create_meta_embed(self):
param = torch.Tensor(1, self._time_embed_dim)
trunc_normal_(param, std=0.02)
param = param.to(self._super_meta_embed.device)
param = torch.nn.Parameter(param, True)
return param
def get_closest_meta_distance(self, timestamp):
with torch.no_grad():
distances = torch.abs(self.meta_timestamps - timestamp)
return torch.min(distances).item()
def replace_append_learnt(self, timestamp, meta_embed):
self._append_meta_timestamps["learnt"] = timestamp
self._append_meta_embed["learnt"] = meta_embed
@property
def meta_length(self):
return self.meta_timestamps.numel()
def clear_fixed(self):
self._append_meta_timestamps["fixed"] = None
self._append_meta_embed["fixed"] = None
def clear_learnt(self):
self.replace_append_learnt(None, None)
def append_fixed(self, timestamp, meta_embed):
with torch.no_grad():
device = self._super_meta_embed.device
timestamp = timestamp.detach().clone().to(device)
meta_embed = meta_embed.detach().clone().to(device)
if self._append_meta_timestamps["fixed"] is None:
self._append_meta_timestamps["fixed"] = timestamp
else:
self._append_meta_timestamps["fixed"] = torch.cat(
(self._append_meta_timestamps["fixed"], timestamp), dim=0
)
if self._append_meta_embed["fixed"] is None:
self._append_meta_embed["fixed"] = meta_embed
else:
self._append_meta_embed["fixed"] = torch.cat(
(self._append_meta_embed["fixed"], meta_embed), dim=0
)
def gen_time_embed(self, timestamps):
# timestamps is a batch of timestamps
[B] = timestamps.shape
# batch, seq = timestamps.shape
timestamps = timestamps.view(-1, 1)
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
timestamp_v_embed = meta_embeds.unsqueeze(dim=0)
timestamp_q_embed = self._tscalar_embed(timestamps)
timestamp_k_embed = self._tscalar_embed(meta_timestamps.view(1, -1))
# create the mask
mask = (
torch.unsqueeze(timestamps, dim=-1) <= meta_timestamps.view(1, 1, -1)
) | (
torch.abs(
torch.unsqueeze(timestamps, dim=-1) - meta_timestamps.view(1, 1, -1)
)
> self._thresh
)
timestamp_embeds = self._trans_att(
timestamp_q_embed, timestamp_k_embed, timestamp_v_embed, mask
)
return timestamp_embeds[:, -1, :]
def gen_model(self, time_embeds):
B, _ = time_embeds.shape
# create joint embed
num_layer, _ = self._super_layer_embed.shape
# The shape of `joint_embed` is batch * num-layers * input-dim
joint_embeds = torch.cat(
(
time_embeds.view(B, 1, -1).expand(-1, num_layer, -1),
self._super_layer_embed.view(1, num_layer, -1).expand(B, -1, -1),
),
dim=-1,
)
batch_weights = self._generator(joint_embeds)
batch_containers = []
for weights in torch.split(batch_weights, 1):
batch_containers.append(
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
)
return batch_containers
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
raise NotImplementedError
def forward_candidate(self, input):
raise NotImplementedError
def easy_adapt(self, timestamp, time_embed):
with torch.no_grad():
timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, time_embed)
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
distance = self.get_closest_meta_distance(timestamp)
if distance + self._interval * 1e-2 <= self._interval:
return False, None
x, y = x.to(self._meta_timestamps.device), y.to(self._meta_timestamps.device)
with torch.set_grad_enabled(True):
new_param = self.create_meta_embed()
optimizer = torch.optim.Adam(
[new_param], lr=lr, weight_decay=1e-5, amsgrad=True
)
timestamp = torch.Tensor([timestamp]).to(new_param.device)
self.replace_append_learnt(timestamp, new_param)
self.train()
base_model.train()
if init_info is not None:
best_loss = init_info["loss"]
new_param.data.copy_(init_info["param"].data)
else:
best_loss = 1e9
with torch.no_grad():
best_new_param = new_param.detach().clone()
for iepoch in range(epochs):
optimizer.zero_grad()
time_embed = self.gen_time_embed(timestamp.view(1))
match_loss = F.l1_loss(new_param, time_embed)
[container] = self.gen_model(new_param.view(1, -1))
y_hat = base_model.forward_with_container(x, container)
meta_loss = criterion(y_hat, y)
loss = meta_loss + match_loss
loss.backward()
optimizer.step()
if meta_loss.item() < best_loss:
with torch.no_grad():
best_loss = meta_loss.item()
best_new_param = new_param.detach().clone()
self.easy_adapt(timestamp, best_new_param)
return True, best_loss
def extra_repr(self) -> str:
return "(_super_layer_embed): {:}, (_super_meta_embed): {:}, (_meta_timestamps): {:}".format(
list(self._super_layer_embed.shape),
list(self._super_meta_embed.shape),
list(self._meta_timestamps.shape),
)

View File

@@ -0,0 +1,441 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
############################################################################
# python exps/GeMOSA/vis-synthetic.py --env_version v1 #
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
# python exps/GeMOSA/vis-synthetic.py --env_version v3 #
# python exps/GeMOSA/vis-synthetic.py --env_version v4 #
############################################################################
import os, sys, copy, random
import torch
import numpy as np
import argparse
from collections import OrderedDict, defaultdict
from pathlib import Path
from tqdm import tqdm
from pprint import pprint
import matplotlib
from matplotlib import cm
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.models.xcore import get_model
from xautodl.datasets.synthetic_core import get_synthetic_env
from xautodl.procedures.metric_utils import MSEMetric
def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):
cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label)
cur_ax.scatter(
xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None
)
def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
save_path = save_dir / "{:04d}".format(timestamp)
# print('Plot the figure at timestamp-{:} into {:}'.format(timestamp, save_path))
dpi, width, height = 40, wh[0], wh[1]
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize, font_gap = 80, 80, 5
fig = plt.figure(figsize=figsize)
if fig_title is not None:
fig.suptitle(
fig_title, fontsize=LegendFontsize, fontweight="bold", x=0.5, y=0.92
)
for idx, scatter_dict in enumerate(scatter_list):
cur_ax = fig.add_subplot(len(scatter_list), 1, idx + 1)
plot_scatter(
cur_ax,
scatter_dict["xaxis"],
scatter_dict["yaxis"],
scatter_dict["color"],
scatter_dict["alpha"],
scatter_dict["linewidths"],
scatter_dict["label"],
)
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
cur_ax.set_xlim(scatter_dict["xlim"][0], scatter_dict["xlim"][1])
cur_ax.set_ylim(scatter_dict["ylim"][0], scatter_dict["ylim"][1])
for tick in cur_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")
def find_min(cur, others):
if cur is None:
return float(others)
else:
return float(min(cur, others))
def find_max(cur, others):
if cur is None:
return float(others.max())
else:
return float(max(cur, others))
def compare_cl(save_dir):
save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env, cl_function = create_example_v1(
# timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
timestamp_config=dict(num=200),
num_per_task=1000,
)
models = dict()
cl_function.set_timestamp(0)
cl_xaxis_min = None
cl_xaxis_max = None
all_data = OrderedDict()
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
xaxis_all = dataset[0][:, 0].numpy()
yaxis_all = dataset[1][:, 0].numpy()
current_data = dict()
current_data["lfna_xaxis_all"] = xaxis_all
current_data["lfna_yaxis_all"] = yaxis_all
# compute cl-min
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std())
all_data[timestamp] = current_data
global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1)
global_cl_yaxis_all = cl_function.noise_call(global_cl_xaxis_all)
for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)):
scatter_list = []
scatter_list.append(
{
"xaxis": xdata["lfna_xaxis_all"],
"yaxis": xdata["lfna_yaxis_all"],
"color": "k",
"linewidths": 15,
"alpha": 0.99,
"xlim": (-6, 6),
"ylim": (-40, 40),
"label": "LFNA",
}
)
cur_cl_xaxis_min = cl_xaxis_min
cur_cl_xaxis_max = cl_xaxis_min + (cl_xaxis_max - cl_xaxis_min) * (
idx + 1
) / len(all_data)
cl_xaxis_all = np.arange(cur_cl_xaxis_min, cur_cl_xaxis_max, step=0.01)
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all, std=0.2)
scatter_list.append(
{
"xaxis": cl_xaxis_all,
"yaxis": cl_yaxis_all,
"color": "k",
"linewidths": 15,
"xlim": (round(cl_xaxis_min, 1), round(cl_xaxis_max, 1)),
"ylim": (-20, 6),
"alpha": 0.99,
"label": "Continual Learning",
}
)
draw_multi_fig(
save_dir,
idx,
scatter_list,
wh=(2200, 1800),
fig_title="Timestamp={:03d}".format(idx),
)
print("Save all figures into {:}".format(save_dir))
save_dir = save_dir.resolve()
base_cmd = (
"ffmpeg -y -i {xdir}/%04d.png -vf fps=1 -vf scale=2200:1800 -vb 5000k".format(
xdir=save_dir
)
)
video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(
base_cmd, xdir=save_dir
)
print(video_cmd + "\n")
os.system(video_cmd)
os.system(
"{:} -pix_fmt yuv420p {xdir}/compare-cl.webm".format(base_cmd, xdir=save_dir)
)
def visualize_env(save_dir, version):
save_dir = Path(str(save_dir))
for substr in ("pdf", "png"):
sub_save_dir = save_dir / "{:}-{:}".format(substr, version)
sub_save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env = get_synthetic_env(version=version)
print("env: {:}".format(dynamic_env))
print("oracle_map: {:}".format(dynamic_env.oracle_map))
allxs, allys = [], []
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
allxs.append(allx)
allys.append(ally)
if dynamic_env.meta_info["task"] == "regression":
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
print(
"x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())
)
print(
"y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())
)
elif dynamic_env.meta_info["task"] == "classification":
allxs = torch.cat(allxs)
print(
"x[0] - min={:.3f}, max={:.3f}".format(
allxs[:, 0].min().item(), allxs[:, 0].max().item()
)
)
print(
"x[1] - min={:.3f}, max={:.3f}".format(
allxs[:, 1].min().item(), allxs[:, 1].max().item()
)
)
else:
raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
dpi, width, height = 30, 1800, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize, font_gap = 80, 80, 5
fig = plt.figure(figsize=figsize)
cur_ax = fig.add_subplot(1, 1, 1)
if dynamic_env.meta_info["task"] == "regression":
allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy()
plot_scatter(
cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)
)
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
elif dynamic_env.meta_info["task"] == "classification":
positive, negative = ally == 1, ally == 0
# plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx))
plot_scatter(
cur_ax,
allx[positive, 0],
allx[positive, 1],
"r",
0.99,
(20, 10),
"positive",
)
plot_scatter(
cur_ax,
allx[negative, 0],
allx[negative, 1],
"g",
0.99,
(20, 10),
"negative",
)
cur_ax.set_xlim(
round(allxs[:, 0].min().item(), 1), round(allxs[:, 0].max().item(), 1)
)
cur_ax.set_ylim(
round(allxs[:, 1].min().item(), 1), round(allxs[:, 1].max().item(), 1)
)
else:
raise ValueError("Unknown task".format(dynamic_env.meta_info["task"]))
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
for tick in cur_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
pdf_save_path = (
save_dir
/ "pdf-{:}".format(version)
/ "v{:}-{:05d}.pdf".format(version, idx)
)
fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
png_save_path = (
save_dir
/ "png-{:}".format(version)
/ "v{:}-{:05d}.png".format(version, idx)
)
fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")
save_dir = save_dir.resolve()
base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
xdir=save_dir / "png-{:}".format(version), version=version
)
print(base_cmd)
os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version))
def compare_algs(save_dir, version, alg_dir="./outputs/GeMOSA-synthetic"):
save_dir = Path(str(save_dir))
for substr in ("pdf", "png"):
sub_save_dir = save_dir / substr
sub_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 30, 3200, 2000
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize, font_gap = 80, 80, 5
dynamic_env = get_synthetic_env(mode=None, version=version)
allxs, allys = [], []
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
allxs.append(allx)
allys.append(ally)
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
alg_name2dir = OrderedDict()
# alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
# alg_name2dir["MAML"] = "use-maml-s1"
# alg_name2dir["LFNA (fix init)"] = "lfna-fix-init"
if version == "v1":
# alg_name2dir["Optimal"] = "use-same-timestamp"
alg_name2dir[
"GMOA"
] = "lfna-battle-bs128-d16_16_16-s16-lr0.002-wd1e-05-e10000-envv1"
else:
raise ValueError("Invalid version: {:}".format(version))
alg_name2all_containers = OrderedDict()
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
ckp_path = Path(alg_dir) / str(xdir) / "final-ckp.pth"
xdata = torch.load(ckp_path, map_location="cpu")
alg_name2all_containers[alg] = xdata["w_containers"]
# load the basic model
model = get_model(
dict(model_type="norm_mlp"),
input_dim=1,
output_dim=1,
hidden_dims=[16] * 2,
act_cls="gelu",
norm_cls="layer_norm_1d",
)
alg2xs, alg2ys = defaultdict(list), defaultdict(list)
colors = ["r", "g", "b", "m", "y"]
linewidths, skip = 10, 5
for idx, (timestamp, (ori_allx, ori_ally)) in enumerate(
tqdm(dynamic_env, ncols=50)
):
if idx <= skip:
continue
fig = plt.figure(figsize=figsize)
cur_ax = fig.add_subplot(2, 1, 1)
# the data
allx, ally = ori_allx[:, 0].numpy(), ori_ally[:, 0].numpy()
plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
with torch.no_grad():
predicts = model.forward_with_container(
ori_allx, alg_name2all_containers[alg][idx]
)
predicts = predicts.cpu()
# keep data
metric = MSEMetric()
metric(predicts, ori_ally)
predicts = predicts.view(-1).numpy()
alg2xs[alg].append(idx)
alg2ys[alg].append(metric.get_info()["mse"])
plot_scatter(cur_ax, allx, predicts, colors[idx_alg], 0.99, linewidths, alg)
cur_ax.set_xlabel("X", fontsize=LabelSize)
cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize)
for tick in cur_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
cur_ax.legend(loc=1, fontsize=LegendFontsize)
# the trajectory data
cur_ax = fig.add_subplot(2, 1, 2)
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
# plot_scatter(cur_ax, alg2xs[alg], alg2ys[alg], olors[idx_alg], 0.99, linewidths, alg)
cur_ax.plot(
alg2xs[alg],
alg2ys[alg],
color=colors[idx_alg],
linestyle="-",
linewidth=5,
label=alg,
)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
cur_ax.set_xlabel("Timestamp", fontsize=LabelSize)
cur_ax.set_ylabel("MSE", fontsize=LabelSize)
for tick in cur_ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
cur_ax.set_xlim(1, len(dynamic_env))
cur_ax.set_ylim(0, 10)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx - skip)
fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx - skip)
fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")
save_dir = save_dir.resolve()
base_cmd = "ffmpeg -y -i {xdir}/v{ver}-%05d.png -vf scale={w}:{h} -pix_fmt yuv420p -vb 5000k".format(
xdir=save_dir / "png", w=width, h=height, ver=version
)
os.system(
"{:} {xdir}/com-alg-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version)
)
os.system(
"{:} {xdir}/com-alg-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Visualize synthetic data.")
parser.add_argument(
"--save_dir",
type=str,
default="./outputs/vis-synthetic",
help="The save directory.",
)
parser.add_argument(
"--env_version",
type=str,
required=True,
help="The synthetic enviornment version.",
)
args = parser.parse_args()
visualize_env(os.path.join(args.save_dir, "vis-env"), args.env_version)
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
# compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
# compare_cl(os.path.join(args.save_dir, "compare-cl"))

View File

@@ -0,0 +1,66 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###########################################################################################################################################################
# Before run these commands, the files must be properly put.
#
# python exps/experimental/example-nas-bench.py --api_path $HOME/.torch/NAS-Bench-201-v1_1-096897.pth --archive_path $HOME/.torch/NAS-Bench-201-v1_1-archive
###########################################################################################################################################################
import os, gc, sys, math, argparse, psutil
import numpy as np
import torch
from pathlib import Path
from collections import OrderedDict
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API
from log_utils import time_string
from models import get_cell_based_tiny_net
from utils import weight_watcher
if __name__ == "__main__":
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument(
"--api_path",
type=str,
default=None,
help="The path to the NAS-Bench-201 benchmark file and weight dir.",
)
parser.add_argument(
"--archive_path",
type=str,
default=None,
help="The path to the NAS-Bench-201 weight dir.",
)
args = parser.parse_args()
meta_file = Path(args.api_path)
weight_dir = Path(args.archive_path)
assert meta_file.exists(), "invalid path for api : {:}".format(meta_file)
assert (
weight_dir.exists() and weight_dir.is_dir()
), "invalid path for weight dir : {:}".format(weight_dir)
api = NASBench201API(meta_file, verbose=True)
arch_index = 3 # query the 3-th architecture
api.reload(weight_dir, arch_index) # reload the data of 3-th from archive dir
data = "cifar10" # query the info from CIFAR-10
config = api.get_net_config(arch_index, data)
net = get_cell_based_tiny_net(config)
meta_info = api.query_meta_info_by_index(
arch_index, hp="200"
) # all info about this architecture
params = meta_info.get_net_param(data, 888)
net.load_state_dict(params)
_, summary = weight_watcher.analyze(net, alphas=False)
print("The summary of {:}-th architecture:\n{:}".format(arch_index, summary))

View File

@@ -0,0 +1,57 @@
from dks.base.activation_getter import (
get_activation_function as _get_numpy_activation_function,
)
from dks.base.activation_transform import _get_activations_params
def subnet_max_func(x, r_fn):
depth = 7
res_x = r_fn(x)
x = r_fn(x)
for _ in range(depth):
x = r_fn(r_fn(x)) + x
return max(x, res_x)
def subnet_max_func_v2(x, r_fn):
depth = 2
res_x = r_fn(x)
x = r_fn(x)
for _ in range(depth):
x = 0.8 * r_fn(r_fn(x)) + 0.2 * x
return max(x, res_x)
def get_transformed_activations(
activation_names,
method="TAT",
dks_params=None,
tat_params=None,
max_slope_func=None,
max_curv_func=None,
subnet_max_func=None,
activation_getter=_get_numpy_activation_function,
):
params = _get_activations_params(
activation_names,
method=method,
dks_params=dks_params,
tat_params=tat_params,
max_slope_func=max_slope_func,
max_curv_func=max_curv_func,
subnet_max_func=subnet_max_func,
)
return params
params = get_transformed_activations(
["swish"], method="TAT", subnet_max_func=subnet_max_func
)
print(params)
params = get_transformed_activations(
["leaky_relu"], method="TAT", subnet_max_func=subnet_max_func_v2
)
print(params)

View File

@@ -0,0 +1,21 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
#####################################################
# python test-dynamic.py
#####################################################
import sys
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.datasets.math_core import ConstantFunc
from xautodl.datasets.math_core import GaussianDGenerator
mean_generator = ConstantFunc(0)
cov_generator = ConstantFunc(1)
generator = GaussianDGenerator([mean_generator], [[cov_generator]], (-1, 1))
generator(0, 10)

View File

@@ -0,0 +1,28 @@
import sys, time, random, argparse
from copy import deepcopy
import torchvision.models as models
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from utils import get_model_infos
# from models.ImageNet_MobileNetV2 import MobileNetV2
from torchvision.models.mobilenet import MobileNetV2
def main(width_mult):
# model = MobileNetV2(1001, width_mult, 32, 1280, 'InvertedResidual', 0.2)
model = MobileNetV2(width_mult=width_mult)
print(model)
flops, params = get_model_infos(model, (2, 3, 224, 224))
print("FLOPs : {:}".format(flops))
print("Params : {:}".format(params))
print("-" * 50)
if __name__ == "__main__":
main(1.0)
main(1.4)

View File

@@ -0,0 +1,168 @@
# python ./exps/vis/test.py
import os, sys, random
from pathlib import Path
from copy import deepcopy
import torch
import numpy as np
from collections import OrderedDict
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from nas_201_api import NASBench201API as API
def test_nas_api():
from nas_201_api import ArchResults
xdata = torch.load(
"/home/dxy/FOR-RELEASE/NAS-Projects/output/NAS-BENCH-201-4/simplifies/architectures/000157-FULL.pth"
)
for key in ["full", "less"]:
print("\n------------------------- {:} -------------------------".format(key))
archRes = ArchResults.create_from_state_dict(xdata[key])
print(archRes)
print(archRes.arch_idx_str())
print(archRes.get_dataset_names())
print(archRes.get_comput_costs("cifar10-valid"))
# get the metrics
print(archRes.get_metrics("cifar10-valid", "x-valid", None, False))
print(archRes.get_metrics("cifar10-valid", "x-valid", None, True))
print(archRes.query("cifar10-valid", 777))
OPS = ["skip-connect", "conv-1x1", "conv-3x3", "pool-3x3"]
COLORS = ["chartreuse", "cyan", "navyblue", "chocolate1"]
def plot(filename):
from graphviz import Digraph
g = Digraph(
format="png",
edge_attr=dict(fontsize="20", fontname="times"),
node_attr=dict(
style="filled",
shape="rect",
align="center",
fontsize="20",
height="0.5",
width="0.5",
penwidth="2",
fontname="times",
),
engine="dot",
)
g.body.extend(["rankdir=LR"])
steps = 5
for i in range(0, steps):
if i == 0:
g.node(str(i), fillcolor="darkseagreen2")
elif i + 1 == steps:
g.node(str(i), fillcolor="palegoldenrod")
else:
g.node(str(i), fillcolor="lightblue")
for i in range(1, steps):
for xin in range(i):
op_i = random.randint(0, len(OPS) - 1)
# g.edge(str(xin), str(i), label=OPS[op_i], fillcolor=COLORS[op_i])
g.edge(
str(xin),
str(i),
label=OPS[op_i],
color=COLORS[op_i],
fillcolor=COLORS[op_i],
)
# import pdb; pdb.set_trace()
g.render(filename, cleanup=True, view=False)
def test_auto_grad():
class Net(torch.nn.Module):
def __init__(self, iS):
super(Net, self).__init__()
self.layer = torch.nn.Linear(iS, 1)
def forward(self, inputs):
outputs = self.layer(inputs)
outputs = torch.exp(outputs)
return outputs.mean()
net = Net(10)
inputs = torch.rand(256, 10)
loss = net(inputs)
first_order_grads = torch.autograd.grad(
loss, net.parameters(), retain_graph=True, create_graph=True
)
first_order_grads = torch.cat([x.view(-1) for x in first_order_grads])
second_order_grads = []
for grads in first_order_grads:
s_grads = torch.autograd.grad(grads, net.parameters())
second_order_grads.append(s_grads)
def test_one_shot_model(ckpath, use_train):
from models import get_cell_based_tiny_net, get_search_spaces
from datasets import get_datasets, SearchDataset
from config_utils import load_config, dict2config
from utils.nas_utils import evaluate_one_shot
use_train = int(use_train) > 0
# ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-11416-basic.pth'
# ckpath = 'output/search-cell-nas-bench-201/DARTS-V1-cifar10/checkpoint/seed-28640-basic.pth'
print("ckpath : {:}".format(ckpath))
ckp = torch.load(ckpath)
xargs = ckp["args"]
train_data, valid_data, xshape, class_num = get_datasets(
xargs.dataset, xargs.data_path, -1
)
# config = load_config(xargs.config_path, {'class_num': class_num, 'xshape': xshape}, None)
config = load_config(
"./configs/nas-benchmark/algos/DARTS.config",
{"class_num": class_num, "xshape": xshape},
None,
)
if xargs.dataset == "cifar10":
cifar_split = load_config("configs/nas-benchmark/cifar-split.txt", None, None)
xvalid_data = deepcopy(train_data)
xvalid_data.transform = valid_data.transform
valid_loader = torch.utils.data.DataLoader(
xvalid_data,
batch_size=2048,
sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar_split.valid),
num_workers=12,
pin_memory=True,
)
else:
raise ValueError("invalid dataset : {:}".format(xargs.dataseet))
search_space = get_search_spaces("cell", xargs.search_space_name)
model_config = dict2config(
{
"name": "SETN",
"C": xargs.channel,
"N": xargs.num_cells,
"max_nodes": xargs.max_nodes,
"num_classes": class_num,
"space": search_space,
"affine": False,
"track_running_stats": True,
},
None,
)
search_model = get_cell_based_tiny_net(model_config)
search_model.load_state_dict(ckp["search_model"])
search_model = search_model.cuda()
api = API("/home/dxy/.torch/NAS-Bench-201-v1_0-e61699.pth")
archs, probs, accuracies = evaluate_one_shot(
search_model, valid_loader, api, use_train
)
if __name__ == "__main__":
# test_nas_api()
# for i in range(200): plot('{:04d}'.format(i))
# test_auto_grad()
test_one_shot_model(sys.argv[1], sys.argv[2])

View File

@@ -0,0 +1,31 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
#####################################################
# python exps/experimental/test-resnest.py
#####################################################
import sys, time, torch, random, argparse
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from utils import get_model_infos
torch.hub.list("zhanghang1989/ResNeSt", force_reload=True)
for model_name, xshape in [
("resnest50", (1, 3, 224, 224)),
("resnest101", (1, 3, 256, 256)),
("resnest200", (1, 3, 320, 320)),
("resnest269", (1, 3, 416, 416)),
]:
# net = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=True)
net = torch.hub.load("zhanghang1989/ResNeSt", model_name, pretrained=False)
print("Model : {:}, input shape : {:}".format(model_name, xshape))
flops, param = get_model_infos(net, xshape)
print("flops : {:.3f}M".format(flops))
print("params : {:.3f}M".format(param))

View File

@@ -0,0 +1,198 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
###########################################################################################################################################################
# Before run these commands, the files must be properly put.
#
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset cifar100
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space sss --base_path $HOME/.torch/NATS-sss-v1_0-50262 --dataset ImageNet16-120
# CUDA_VISIBLE_DEVICES='' OMP_NUM_THREADS=4 python exps/experimental/test-ww-bench.py --search_space tss --base_path $HOME/.torch/NATS-tss-v1_0-3ffb9 --dataset cifar10
###########################################################################################################################################################
import os, gc, sys, math, argparse, psutil
import numpy as np
import torch
from pathlib import Path
from collections import OrderedDict
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
print("LIB-DIR: {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from log_utils import time_string
from nats_bench import create
from models import get_cell_based_tiny_net
from utils import weight_watcher
"""
def get_cor(A, B):
return float(np.corrcoef(A, B)[0,1])
def tostr(accdict, norms):
xstr = []
for key, accs in accdict.items():
cor = get_cor(accs, norms)
xstr.append('{:}: {:.3f}'.format(key, cor))
return ' '.join(xstr)
"""
def evaluate(api, weight_dir, data: str):
print("\nEvaluate dataset={:}".format(data))
process = psutil.Process(os.getpid())
norms, accuracies = [], []
ok, total = 0, 5000
for idx in range(total):
arch_index = api.random()
api.reload(weight_dir, arch_index)
# compute the weight watcher results
config = api.get_net_config(arch_index, data)
net = get_cell_based_tiny_net(config)
meta_info = api.query_meta_info_by_index(
arch_index, hp="200" if api.search_space_name == "topology" else "90"
)
params = meta_info.get_net_param(
data, 888 if api.search_space_name == "topology" else 777
)
with torch.no_grad():
net.load_state_dict(params)
_, summary = weight_watcher.analyze(net, alphas=False)
if "lognorm" not in summary:
api.clear_params(arch_index, None)
del net
continue
continue
cur_norm = -summary["lognorm"]
api.clear_params(arch_index, None)
if math.isnan(cur_norm):
del net, meta_info
continue
else:
ok += 1
norms.append(cur_norm)
# query the accuracy
info = meta_info.get_metrics(
data,
"ori-test",
iepoch=None,
is_random=888 if api.search_space_name == "topology" else 777,
)
accuracies.append(info["accuracy"])
del net, meta_info
# print the information
if idx % 20 == 0:
gc.collect()
print(
"{:} {:04d}_{:04d}/{:04d} ({:.2f} MB memory)".format(
time_string(), ok, idx, total, process.memory_info().rss / 1e6
)
)
return norms, accuracies
def main(search_space, meta_file: str, weight_dir, save_dir, xdata):
save_dir.mkdir(parents=True, exist_ok=True)
api = create(meta_file, search_space, verbose=False)
datasets = ["cifar10-valid", "cifar10", "cifar100", "ImageNet16-120"]
print(time_string() + " " + "=" * 50)
for data in datasets:
hps = api.avaliable_hps
for hp in hps:
nums = api.statistics(data, hp=hp)
total = sum([k * v for k, v in nums.items()])
print(
"Using {:3s} epochs, trained on {:20s} : {:} trials in total ({:}).".format(
hp, data, total, nums
)
)
print(time_string() + " " + "=" * 50)
norms, accuracies = evaluate(api, weight_dir, xdata)
indexes = list(range(len(norms)))
norm_indexes = sorted(indexes, key=lambda i: norms[i])
accy_indexes = sorted(indexes, key=lambda i: accuracies[i])
labels = []
for index in norm_indexes:
labels.append(accy_indexes.index(index))
dpi, width, height = 200, 1400, 800
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 12
resnet_scale, resnet_alpha = 120, 0.5
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(
np.arange(min(indexes), max(indexes), max(indexes) // 3),
fontsize=LegendFontsize,
rotation="vertical",
)
plt.xticks(
np.arange(min(indexes), max(indexes), max(indexes) // 5),
fontsize=LegendFontsize,
)
ax.scatter(indexes, labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="Test accuracy")
ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="Weight watcher")
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel(
"architecture ranking sorted by the test accuracy ", fontsize=LabelSize
)
ax.set_ylabel("architecture ranking computed by weight watcher", fontsize=LabelSize)
save_path = (save_dir / "{:}-{:}-test-ww.pdf".format(search_space, xdata)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (save_dir / "{:}-{:}-test-ww.png".format(search_space, xdata)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
print("{:} finish this test.".format(time_string()))
if __name__ == "__main__":
parser = argparse.ArgumentParser("Analysis of NAS-Bench-201")
parser.add_argument(
"--save_dir",
type=str,
default="./output/vis-nas-bench/",
help="The base-name of folder to save checkpoints and log.",
)
parser.add_argument(
"--search_space",
type=str,
default=None,
choices=["tss", "sss"],
help="The search space.",
)
parser.add_argument(
"--base_path",
type=str,
default=None,
help="The path to the NAS-Bench-201 benchmark file and weight dir.",
)
parser.add_argument("--dataset", type=str, default=None, help=".")
args = parser.parse_args()
save_dir = Path(args.save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
meta_file = Path(args.base_path + ".pth")
weight_dir = Path(args.base_path + "-full")
assert meta_file.exists(), "invalid path for api : {:}".format(meta_file)
assert (
weight_dir.exists() and weight_dir.is_dir()
), "invalid path for weight dir : {:}".format(weight_dir)
main(args.search_space, str(meta_file), weight_dir, save_dir, args.dataset)

View File

@@ -0,0 +1,30 @@
import sys, time, random, argparse
from copy import deepcopy
import torchvision.models as models
from pathlib import Path
from xautodl.utils import weight_watcher
def main():
# model = models.vgg19_bn(pretrained=True)
# _, summary = weight_watcher.analyze(model, alphas=False)
# for key, value in summary.items():
# print('{:10s} : {:}'.format(key, value))
_, summary = weight_watcher.analyze(models.vgg13(pretrained=True), alphas=False)
print("vgg-13 : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg13_bn(pretrained=True), alphas=False)
print("vgg-13-BN : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg16(pretrained=True), alphas=False)
print("vgg-16 : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg16_bn(pretrained=True), alphas=False)
print("vgg-16-BN : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg19(pretrained=True), alphas=False)
print("vgg-19 : {:}".format(summary["lognorm"]))
_, summary = weight_watcher.analyze(models.vgg19_bn(pretrained=True), alphas=False)
print("vgg-19-BN : {:}".format(summary["lognorm"]))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,178 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/experimental/vis-nats-bench-algos.py --search_space tss
# Usage: python exps/experimental/vis-nats-bench-algos.py --search_space sss
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
def fetch_data(root_dir="./output/search", search_space="tss", dataset=None):
ss_dir = "{:}-{:}".format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
alg2name["REA"] = "R-EA-SS3"
alg2name["REINFORCE"] = "REINFORCE-0.01"
alg2name["RANDOM"] = "RANDOM"
alg2name["BOHB"] = "BOHB"
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, "results.pth")
assert os.path.isfile(alg2path[alg]), "invalid path : {:}".format(alg2path[alg])
alg2data = OrderedDict()
for alg, path in alg2path.items():
data = torch.load(path)
for index, info in data.items():
info["time_w_arch"] = [
(x, y) for x, y in zip(info["all_total_times"], info["all_archs"])
]
for j, arch in enumerate(info["all_archs"]):
assert arch != -1, "invalid arch from {:} {:} {:} ({:}, {:})".format(
alg, search_space, dataset, index, j
)
alg2data[alg] = data
return alg2data
def query_performance(api, data, dataset, ticket):
results, is_size_space = [], api.search_space_name == "size"
for i, info in data.items():
time_w_arch = sorted(info["time_w_arch"], key=lambda x: abs(x[0] - ticket))
time_a, arch_a = time_w_arch[0]
time_b, arch_b = time_w_arch[1]
info_a = api.get_more_info(
arch_a, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
info_b = api.get_more_info(
arch_b, dataset=dataset, hp=90 if is_size_space else 200, is_random=False
)
accuracy_a, accuracy_b = info_a["test-accuracy"], info_b["test-accuracy"]
interplate = (time_b - ticket) / (time_b - time_a) * accuracy_a + (
ticket - time_a
) / (time_b - time_a) * accuracy_b
results.append(interplate)
return sum(results) / len(results)
y_min_s = {
("cifar10", "tss"): 90,
("cifar10", "sss"): 92,
("cifar100", "tss"): 65,
("cifar100", "sss"): 65,
("ImageNet16-120", "tss"): 36,
("ImageNet16-120", "sss"): 40,
}
y_max_s = {
("cifar10", "tss"): 94.5,
("cifar10", "sss"): 93.3,
("cifar100", "tss"): 72,
("cifar100", "sss"): 70,
("ImageNet16-120", "tss"): 44,
("ImageNet16-120", "sss"): 46,
}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
def visualize_curve(api, vis_save_dir, search_space, max_time):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
def sub_plot_fn(ax, dataset):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
total_tickets = 150
time_tickets = [
float(i) / total_tickets * max_time for i in range(total_tickets)
]
colors = ["b", "g", "c", "m", "y"]
ax.set_xlim(0, 200)
ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
print("plot alg : {:}".format(alg))
accuracies = []
for ticket in time_tickets:
accuracy = query_performance(api, data, dataset, ticket)
accuracies.append(accuracy)
alg2accuracies[alg] = accuracies
ax.plot(
[x / 100 for x in time_tickets],
accuracies,
c=colors[idx],
label="{:}".format(alg),
)
ax.set_xlabel("Estimated wall-clock time (1e2 seconds)", fontsize=LabelSize)
ax.set_ylabel(
"Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize
)
ax.set_title(
"Searching results on {:}".format(name2label[dataset]),
fontsize=LabelSize + 4,
)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print("sub-plot {:} on {:} done.".format(dataset, search_space))
save_path = (vis_save_dir / "{:}-curve.png".format(search_space)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NAS-Bench-X",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--search_space",
type=str,
choices=["tss", "sss"],
help="Choose the search space.",
)
parser.add_argument(
"--max_time", type=float, default=20000, help="The maximum time budget."
)
args = parser.parse_args()
save_dir = Path(args.save_dir)
api = create(None, args.search_space, verbose=False)
visualize_curve(api, save_dir, args.search_space, args.max_time)

View File

@@ -0,0 +1,185 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/experimental/vis-nats-bench-ws.py --search_space tss
# Usage: python exps/experimental/vis-nats-bench-ws.py --search_space sss
###############################################################
import os, gc, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict, OrderedDict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from nats_bench import create
from log_utils import time_string
# def fetch_data(root_dir='./output/search', search_space='tss', dataset=None, suffix='-WARMNone'):
def fetch_data(
root_dir="./output/search", search_space="tss", dataset=None, suffix="-WARM0.3"
):
ss_dir = "{:}-{:}".format(root_dir, search_space)
alg2name, alg2path = OrderedDict(), OrderedDict()
seeds = [777, 888, 999]
print("\n[fetch data] from {:} on {:}".format(search_space, dataset))
if search_space == "tss":
alg2name["GDAS"] = "gdas-affine0_BN0-None"
alg2name["RSPS"] = "random-affine0_BN0-None"
alg2name["DARTS (1st)"] = "darts-v1-affine0_BN0-None"
alg2name["DARTS (2nd)"] = "darts-v2-affine0_BN0-None"
alg2name["ENAS"] = "enas-affine0_BN0-None"
alg2name["SETN"] = "setn-affine0_BN0-None"
else:
# alg2name['TAS'] = 'tas-affine0_BN0{:}'.format(suffix)
# alg2name['FBNetV2'] = 'fbv2-affine0_BN0{:}'.format(suffix)
# alg2name['TuNAS'] = 'tunas-affine0_BN0{:}'.format(suffix)
alg2name["channel-wise interpolation"] = "tas-affine0_BN0-AWD0.001{:}".format(
suffix
)
alg2name[
"masking + Gumbel-Softmax"
] = "mask_gumbel-affine0_BN0-AWD0.001{:}".format(suffix)
alg2name["masking + sampling"] = "mask_rl-affine0_BN0-AWD0.0{:}".format(suffix)
for alg, name in alg2name.items():
alg2path[alg] = os.path.join(ss_dir, dataset, name, "seed-{:}-last-info.pth")
alg2data = OrderedDict()
for alg, path in alg2path.items():
alg2data[alg], ok_num = [], 0
for seed in seeds:
xpath = path.format(seed)
if os.path.isfile(xpath):
ok_num += 1
else:
print("This is an invalid path : {:}".format(xpath))
continue
data = torch.load(xpath, map_location=torch.device("cpu"))
data = torch.load(data["last_checkpoint"], map_location=torch.device("cpu"))
alg2data[alg].append(data["genotypes"])
print("This algorithm : {:} has {:} valid ckps.".format(alg, ok_num))
assert ok_num > 0, "Must have at least 1 valid ckps."
return alg2data
y_min_s = {
("cifar10", "tss"): 90,
("cifar10", "sss"): 92,
("cifar100", "tss"): 65,
("cifar100", "sss"): 65,
("ImageNet16-120", "tss"): 36,
("ImageNet16-120", "sss"): 40,
}
y_max_s = {
("cifar10", "tss"): 94.5,
("cifar10", "sss"): 93.3,
("cifar100", "tss"): 72,
("cifar100", "sss"): 70,
("ImageNet16-120", "tss"): 44,
("ImageNet16-120", "sss"): 46,
}
name2label = {
"cifar10": "CIFAR-10",
"cifar100": "CIFAR-100",
"ImageNet16-120": "ImageNet-16-120",
}
def visualize_curve(api, vis_save_dir, search_space):
vis_save_dir = vis_save_dir.resolve()
vis_save_dir.mkdir(parents=True, exist_ok=True)
dpi, width, height = 250, 5200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 16, 16
def sub_plot_fn(ax, dataset):
alg2data = fetch_data(search_space=search_space, dataset=dataset)
alg2accuracies = OrderedDict()
epochs = 100
colors = ["b", "g", "c", "m", "y", "r"]
ax.set_xlim(0, epochs)
# ax.set_ylim(y_min_s[(dataset, search_space)], y_max_s[(dataset, search_space)])
for idx, (alg, data) in enumerate(alg2data.items()):
print("plot alg : {:}".format(alg))
xs, accuracies = [], []
for iepoch in range(epochs + 1):
try:
structures, accs = [_[iepoch - 1] for _ in data], []
except:
raise ValueError(
"This alg {:} on {:} has invalid checkpoints.".format(
alg, dataset
)
)
for structure in structures:
info = api.get_more_info(
structure,
dataset=dataset,
hp=90 if api.search_space_name == "size" else 200,
is_random=False,
)
accs.append(info["test-accuracy"])
accuracies.append(sum(accs) / len(accs))
xs.append(iepoch)
alg2accuracies[alg] = accuracies
ax.plot(xs, accuracies, c=colors[idx], label="{:}".format(alg))
ax.set_xlabel("The searching epoch", fontsize=LabelSize)
ax.set_ylabel(
"Test accuracy on {:}".format(name2label[dataset]), fontsize=LabelSize
)
ax.set_title(
"Searching results on {:}".format(name2label[dataset]),
fontsize=LabelSize + 4,
)
ax.legend(loc=4, fontsize=LegendFontsize)
fig, axs = plt.subplots(1, 3, figsize=figsize)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
for dataset, ax in zip(datasets, axs):
sub_plot_fn(ax, dataset)
print("sub-plot {:} on {:} done.".format(dataset, search_space))
save_path = (vis_save_dir / "{:}-ws-curve.png".format(search_space)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NAS-Bench-X",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir",
type=str,
default="output/vis-nas-bench/nas-algos",
help="Folder to save checkpoints and log.",
)
parser.add_argument(
"--search_space",
type=str,
default="tss",
choices=["tss", "sss"],
help="Choose the search space.",
)
args = parser.parse_args()
save_dir = Path(args.save_dir)
api = create(None, args.search_space, fast_mode=True, verbose=False)
visualize_curve(api, save_dir, args.search_space)

View File

@@ -0,0 +1,657 @@
###############################################################
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
###############################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
###############################################################
# Usage: python exps/experimental/visualize-nas-bench-x.py
###############################################################
import os, sys, time, torch, argparse
import numpy as np
from typing import List, Text, Dict, Any
from shutil import copyfile
from collections import defaultdict
from copy import deepcopy
from pathlib import Path
import matplotlib
import seaborn as sns
matplotlib.use("agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import dict2config, load_config
from log_utils import time_string
from models import get_cell_based_tiny_net
from nats_bench import create
def visualize_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info["params"])))
print("{:} start to visualize relative ranking".format(time_string()))
cifar010_ord_indexes = sorted(indexes, key=lambda i: cifar010_info["test_accs"][i])
cifar100_ord_indexes = sorted(indexes, key=lambda i: cifar100_info["test_accs"][i])
imagenet_ord_indexes = sorted(indexes, key=lambda i: imagenet_info["test_accs"][i])
cifar100_labels, imagenet_labels = [], []
for idx in cifar010_ord_indexes:
cifar100_labels.append(cifar100_ord_indexes.index(idx))
imagenet_labels.append(imagenet_ord_indexes.index(idx))
print("{:} prepare data done.".format(time_string()))
dpi, width, height = 200, 1400, 800
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 18, 12
resnet_scale, resnet_alpha = 120, 0.5
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111)
plt.xlim(min(indexes), max(indexes))
plt.ylim(min(indexes), max(indexes))
# plt.ylabel('y').set_rotation(30)
plt.yticks(
np.arange(min(indexes), max(indexes), max(indexes) // 3),
fontsize=LegendFontsize,
rotation="vertical",
)
plt.xticks(
np.arange(min(indexes), max(indexes), max(indexes) // 5),
fontsize=LegendFontsize,
)
ax.scatter(indexes, cifar100_labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, imagenet_labels, marker="*", s=0.5, c="tab:red", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter([-1], [-1], marker="o", s=100, c="tab:blue", label="CIFAR-10")
ax.scatter([-1], [-1], marker="^", s=100, c="tab:green", label="CIFAR-100")
ax.scatter([-1], [-1], marker="*", s=100, c="tab:red", label="ImageNet-16-120")
plt.grid(zorder=0)
ax.set_axisbelow(True)
plt.legend(loc=0, fontsize=LegendFontsize)
ax.set_xlabel("architecture ranking in CIFAR-10", fontsize=LabelSize)
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
save_path = (vis_save_dir / "{:}-relative-rank.pdf".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (vis_save_dir / "{:}-relative-rank.png".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
def visualize_sss_info(api, dataset, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
print("{:} start to visualize {:} information".format(time_string(), dataset))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / "{:}-cache-sss-info.pth".format(dataset)
if not cache_file_path.exists():
print("Do not find cache file : {:}".format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
cost_info = api.get_cost_info(index, dataset, hp="90")
params.append(cost_info["params"])
flops.append(cost_info["flops"])
# accuracy
info = api.get_more_info(index, dataset, hp="90", is_random=False)
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(
index, "cifar10-valid", hp="90", is_random=False
)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
info = {
"params": params,
"flops": flops,
"train_accs": train_accs,
"valid_accs": valid_accs,
"test_accs": test_accs,
}
torch.save(info, cache_file_path)
else:
print("Find cache file : {:}".format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs = (
info["params"],
info["flops"],
info["train_accs"],
info["valid_accs"],
info["test_accs"],
)
print("{:} collect data done.".format(time_string()))
pyramid = [
"8:16:32:48:64",
"8:8:16:32:48",
"8:8:16:16:32",
"8:8:16:16:48",
"8:8:16:16:64",
"16:16:32:32:64",
"32:32:64:64:64",
]
pyramid_indexes = [api.query_index_by_arch(x) for x in pyramid]
largest_indexes = [api.query_index_by_arch("64:64:64:64:64")]
indexes = list(range(len(params)))
dpi, width, height = 250, 8500, 1300
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 24, 24
# resnet_scale, resnet_alpha = 120, 0.5
xscale, xalpha = 120, 0.8
fig, axs = plt.subplots(1, 4, figsize=figsize)
# ax1, ax2, ax3, ax4, ax5 = axs
for ax in axs:
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax2, ax3, ax4, ax5 = axs
# ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5))
# ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
# ax1.set_xlabel('architecture ID', fontsize=LabelSize)
# ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax2.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue")
ax2.scatter(
[params[x] for x in pyramid_indexes],
[train_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax2.scatter(
[params[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax2.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax2.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax2.legend(loc=4, fontsize=LegendFontsize)
ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue")
ax3.scatter(
[params[x] for x in pyramid_indexes],
[test_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax3.scatter(
[params[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax3.legend(loc=4, fontsize=LegendFontsize)
ax4.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue")
ax4.scatter(
[flops[x] for x in pyramid_indexes],
[train_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax4.scatter(
[flops[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
ax4.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax4.legend(loc=4, fontsize=LegendFontsize)
ax5.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue")
ax5.scatter(
[flops[x] for x in pyramid_indexes],
[test_accs[x] for x in pyramid_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="Pyramid Structure",
alpha=xalpha,
)
ax5.scatter(
[flops[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax5.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
ax5.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax5.legend(loc=4, fontsize=LegendFontsize)
save_path = vis_save_dir / "sss-{:}.png".format(dataset)
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
def visualize_tss_info(api, dataset, vis_save_dir):
vis_save_dir = vis_save_dir.resolve()
print("{:} start to visualize {:} information".format(time_string(), dataset))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cache_file_path = vis_save_dir / "{:}-cache-tss-info.pth".format(dataset)
if not cache_file_path.exists():
print("Do not find cache file : {:}".format(cache_file_path))
params, flops, train_accs, valid_accs, test_accs = [], [], [], [], []
for index in range(len(api)):
cost_info = api.get_cost_info(index, dataset, hp="12")
params.append(cost_info["params"])
flops.append(cost_info["flops"])
# accuracy
info = api.get_more_info(index, dataset, hp="200", is_random=False)
train_accs.append(info["train-accuracy"])
test_accs.append(info["test-accuracy"])
if dataset == "cifar10":
info = api.get_more_info(
index, "cifar10-valid", hp="200", is_random=False
)
valid_accs.append(info["valid-accuracy"])
else:
valid_accs.append(info["valid-accuracy"])
print("")
info = {
"params": params,
"flops": flops,
"train_accs": train_accs,
"valid_accs": valid_accs,
"test_accs": test_accs,
}
torch.save(info, cache_file_path)
else:
print("Find cache file : {:}".format(cache_file_path))
info = torch.load(cache_file_path)
params, flops, train_accs, valid_accs, test_accs = (
info["params"],
info["flops"],
info["train_accs"],
info["valid_accs"],
info["test_accs"],
)
print("{:} collect data done.".format(time_string()))
resnet = [
"|nor_conv_3x3~0|+|none~0|nor_conv_3x3~1|+|skip_connect~0|none~1|skip_connect~2|"
]
resnet_indexes = [api.query_index_by_arch(x) for x in resnet]
largest_indexes = [
api.query_index_by_arch(
"|nor_conv_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|nor_conv_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|"
)
]
indexes = list(range(len(params)))
dpi, width, height = 250, 8500, 1300
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 24, 24
# resnet_scale, resnet_alpha = 120, 0.5
xscale, xalpha = 120, 0.8
fig, axs = plt.subplots(1, 4, figsize=figsize)
# ax1, ax2, ax3, ax4, ax5 = axs
for ax in axs:
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0f"))
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
ax2, ax3, ax4, ax5 = axs
# ax1.xaxis.set_ticks(np.arange(0, max(indexes), max(indexes)//5))
# ax1.scatter(indexes, test_accs, marker='o', s=0.5, c='tab:blue')
# ax1.set_xlabel('architecture ID', fontsize=LabelSize)
# ax1.set_ylabel('test accuracy (%)', fontsize=LabelSize)
ax2.scatter(params, train_accs, marker="o", s=0.5, c="tab:blue")
ax2.scatter(
[params[x] for x in resnet_indexes],
[train_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax2.scatter(
[params[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax2.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax2.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax2.legend(loc=4, fontsize=LegendFontsize)
ax3.scatter(params, test_accs, marker="o", s=0.5, c="tab:blue")
ax3.scatter(
[params[x] for x in resnet_indexes],
[test_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax3.scatter(
[params[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax3.set_xlabel("#parameters (MB)", fontsize=LabelSize)
ax3.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax3.legend(loc=4, fontsize=LegendFontsize)
ax4.scatter(flops, train_accs, marker="o", s=0.5, c="tab:blue")
ax4.scatter(
[flops[x] for x in resnet_indexes],
[train_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax4.scatter(
[flops[x] for x in largest_indexes],
[train_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax4.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
ax4.set_ylabel("train accuracy (%)", fontsize=LabelSize)
ax4.legend(loc=4, fontsize=LegendFontsize)
ax5.scatter(flops, test_accs, marker="o", s=0.5, c="tab:blue")
ax5.scatter(
[flops[x] for x in resnet_indexes],
[test_accs[x] for x in resnet_indexes],
marker="*",
s=xscale,
c="tab:orange",
label="ResNet",
alpha=xalpha,
)
ax5.scatter(
[flops[x] for x in largest_indexes],
[test_accs[x] for x in largest_indexes],
marker="x",
s=xscale,
c="tab:green",
label="Largest Candidate",
alpha=xalpha,
)
ax5.set_xlabel("#FLOPs (M)", fontsize=LabelSize)
ax5.set_ylabel("test accuracy (%)", fontsize=LabelSize)
ax5.legend(loc=4, fontsize=LegendFontsize)
save_path = vis_save_dir / "tss-{:}.png".format(dataset)
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
def visualize_rank_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info["params"])))
print("{:} start to visualize relative ranking".format(time_string()))
dpi, width, height = 250, 3800, 1200
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, axs = plt.subplots(1, 3, figsize=figsize)
ax1, ax2, ax3 = axs
def get_labels(info):
ord_test_indexes = sorted(indexes, key=lambda i: info["test_accs"][i])
ord_valid_indexes = sorted(indexes, key=lambda i: info["valid_accs"][i])
labels = []
for idx in ord_test_indexes:
labels.append(ord_valid_indexes.index(idx))
return labels
def plot_ax(labels, ax, name):
for tick in ax.xaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
for tick in ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize)
tick.label.set_rotation(90)
ax.set_xlim(min(indexes), max(indexes))
ax.set_ylim(min(indexes), max(indexes))
ax.yaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 3))
ax.xaxis.set_ticks(np.arange(min(indexes), max(indexes), max(indexes) // 5))
ax.scatter(indexes, labels, marker="^", s=0.5, c="tab:green", alpha=0.8)
ax.scatter(indexes, indexes, marker="o", s=0.5, c="tab:blue", alpha=0.8)
ax.scatter(
[-1], [-1], marker="^", s=100, c="tab:green", label="{:} test".format(name)
)
ax.scatter(
[-1],
[-1],
marker="o",
s=100,
c="tab:blue",
label="{:} validation".format(name),
)
ax.legend(loc=4, fontsize=LegendFontsize)
ax.set_xlabel("ranking on the {:} validation".format(name), fontsize=LabelSize)
ax.set_ylabel("architecture ranking", fontsize=LabelSize)
labels = get_labels(cifar010_info)
plot_ax(labels, ax1, "CIFAR-10")
labels = get_labels(cifar100_info)
plot_ax(labels, ax2, "CIFAR-100")
labels = get_labels(imagenet_info)
plot_ax(labels, ax3, "ImageNet-16-120")
save_path = (
vis_save_dir / "{:}-same-relative-rank.pdf".format(indicator)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="pdf")
save_path = (
vis_save_dir / "{:}-same-relative-rank.png".format(indicator)
).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
def calculate_correlation(*vectors):
matrix = []
for i, vectori in enumerate(vectors):
x = []
for j, vectorj in enumerate(vectors):
x.append(np.corrcoef(vectori, vectorj)[0, 1])
matrix.append(x)
return np.array(matrix)
def visualize_all_rank_info(api, vis_save_dir, indicator):
vis_save_dir = vis_save_dir.resolve()
# print ('{:} start to visualize {:} information'.format(time_string(), api))
vis_save_dir.mkdir(parents=True, exist_ok=True)
cifar010_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar10", indicator
)
cifar100_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"cifar100", indicator
)
imagenet_cache_path = vis_save_dir / "{:}-cache-{:}-info.pth".format(
"ImageNet16-120", indicator
)
cifar010_info = torch.load(cifar010_cache_path)
cifar100_info = torch.load(cifar100_cache_path)
imagenet_info = torch.load(imagenet_cache_path)
indexes = list(range(len(cifar010_info["params"])))
print("{:} start to visualize relative ranking".format(time_string()))
dpi, width, height = 250, 3200, 1400
figsize = width / float(dpi), height / float(dpi)
LabelSize, LegendFontsize = 14, 14
fig, axs = plt.subplots(1, 2, figsize=figsize)
ax1, ax2 = axs
sns_size = 15
CoRelMatrix = calculate_correlation(
cifar010_info["valid_accs"],
cifar010_info["test_accs"],
cifar100_info["valid_accs"],
cifar100_info["test_accs"],
imagenet_info["valid_accs"],
imagenet_info["test_accs"],
)
sns.heatmap(
CoRelMatrix,
annot=True,
annot_kws={"size": sns_size},
fmt=".3f",
linewidths=0.5,
ax=ax1,
xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
)
selected_indexes, acc_bar = [], 92
for i, acc in enumerate(cifar010_info["test_accs"]):
if acc > acc_bar:
selected_indexes.append(i)
cifar010_valid_accs = np.array(cifar010_info["valid_accs"])[selected_indexes]
cifar010_test_accs = np.array(cifar010_info["test_accs"])[selected_indexes]
cifar100_valid_accs = np.array(cifar100_info["valid_accs"])[selected_indexes]
cifar100_test_accs = np.array(cifar100_info["test_accs"])[selected_indexes]
imagenet_valid_accs = np.array(imagenet_info["valid_accs"])[selected_indexes]
imagenet_test_accs = np.array(imagenet_info["test_accs"])[selected_indexes]
CoRelMatrix = calculate_correlation(
cifar010_valid_accs,
cifar010_test_accs,
cifar100_valid_accs,
cifar100_test_accs,
imagenet_valid_accs,
imagenet_test_accs,
)
sns.heatmap(
CoRelMatrix,
annot=True,
annot_kws={"size": sns_size},
fmt=".3f",
linewidths=0.5,
ax=ax2,
xticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
yticklabels=["C10-V", "C10-T", "C100-V", "C100-T", "I120-V", "I120-T"],
)
ax1.set_title("Correlation coefficient over ALL candidates")
ax2.set_title(
"Correlation coefficient over candidates with accuracy > {:}%".format(acc_bar)
)
save_path = (vis_save_dir / "{:}-all-relative-rank.png".format(indicator)).resolve()
fig.savefig(save_path, dpi=dpi, bbox_inches="tight", format="png")
print("{:} save into {:}".format(time_string(), save_path))
plt.close("all")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="NAS-Bench-X",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--save_dir",
type=str,
default="output/vis-nas-bench",
help="Folder to save checkpoints and log.",
)
# use for train the model
args = parser.parse_args()
to_save_dir = Path(args.save_dir)
datasets = ["cifar10", "cifar100", "ImageNet16-120"]
api201 = create(None, "tss", verbose=True)
for xdata in datasets:
visualize_tss_info(api201, xdata, to_save_dir)
api_sss = create(None, "size", verbose=True)
for xdata in datasets:
visualize_sss_info(api_sss, xdata, to_save_dir)
visualize_info(None, to_save_dir, "tss")
visualize_info(None, to_save_dir, "sss")
visualize_rank_info(None, to_save_dir, "tss")
visualize_rank_info(None, to_save_dir, "sss")
visualize_all_rank_info(None, to_save_dir, "tss")
visualize_all_rank_info(None, to_save_dir, "sss")