Re-org GeMOSA codes

This commit is contained in:
D-X-Y
2021-05-27 11:17:57 +08:00
parent a507f8dd94
commit 8961215416
8 changed files with 82 additions and 162 deletions

View File

@@ -1,117 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
import copy
import torch
import torch.nn.functional as F
from xlayers import super_core
from xlayers import trunc_normal_
from models.xcore import get_model
class HyperNet(super_core.SuperModule):
"""The hyper-network."""
def __init__(
self,
shape_container,
layer_embeding,
task_embedding,
num_tasks,
return_container=True,
):
super(HyperNet, self).__init__()
self._shape_container = shape_container
self._num_layers = len(shape_container)
self._numel_per_layer = []
for ilayer in range(self._num_layers):
self._numel_per_layer.append(shape_container[ilayer].numel())
self.register_parameter(
"_super_layer_embed",
torch.nn.Parameter(torch.Tensor(self._num_layers, layer_embeding)),
)
self.register_parameter(
"_super_task_embed",
torch.nn.Parameter(torch.Tensor(num_tasks, task_embedding)),
)
trunc_normal_(self._super_layer_embed, std=0.02)
trunc_normal_(self._super_task_embed, std=0.02)
model_kwargs = dict(
config=dict(model_type="dual_norm_mlp"),
input_dim=layer_embeding + task_embedding,
output_dim=max(self._numel_per_layer),
hidden_dims=[(layer_embeding + task_embedding) * 2] * 3,
act_cls="gelu",
norm_cls="layer_norm_1d",
dropout=0.2,
)
self._generator = get_model(**model_kwargs)
self._return_container = return_container
print("generator: {:}".format(self._generator))
def forward_raw(self, task_embed_id):
layer_embed = self._super_layer_embed
task_embed = (
self._super_task_embed[task_embed_id]
.view(1, -1)
.expand(self._num_layers, -1)
)
joint_embed = torch.cat((task_embed, layer_embed), dim=-1)
weights = self._generator(joint_embed)
if self._return_container:
weights = torch.split(weights, 1)
return self._shape_container.translate(weights)
else:
return weights
def forward_candidate(self, input):
raise NotImplementedError
def extra_repr(self) -> str:
return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))
class HyperNet_VX(super_core.SuperModule):
def __init__(self, shape_container, input_embeding, return_container=True):
super(HyperNet_VX, self).__init__()
self._shape_container = shape_container
self._num_layers = len(shape_container)
self._numel_per_layer = []
for ilayer in range(self._num_layers):
self._numel_per_layer.append(shape_container[ilayer].numel())
self.register_parameter(
"_super_layer_embed",
torch.nn.Parameter(torch.Tensor(self._num_layers, input_embeding)),
)
trunc_normal_(self._super_layer_embed, std=0.02)
model_kwargs = dict(
input_dim=input_embeding,
output_dim=max(self._numel_per_layer),
hidden_dim=input_embeding * 4,
act_cls="sigmoid",
norm_cls="identity",
)
self._generator = get_model(dict(model_type="simple_mlp"), **model_kwargs)
self._return_container = return_container
print("generator: {:}".format(self._generator))
def forward_raw(self, input):
weights = self._generator(self._super_layer_embed)
if self._return_container:
weights = torch.split(weights, 1)
return self._shape_container.translate(weights)
else:
return weights
def forward_candidate(self, input):
raise NotImplementedError
def extra_repr(self) -> str:
return "(_super_layer_embed): {:}".format(list(self._super_layer_embed.shape))

View File

@@ -35,7 +35,7 @@ from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core, trunc_normal_
from lfna_utils import lfna_setup, train_model, TimeData
from lfna_meta_model import MetaModelV1
from meta_model import MetaModelV1
def online_evaluate(env, meta_model, base_model, criterion, args, logger, save=False):
@@ -106,7 +106,7 @@ def meta_train_procedure(base_model, meta_model, criterion, xenv, args, logger):
)
optimizer.zero_grad()
generated_time_embeds = meta_model(meta_model.meta_timestamps, None, True)
generated_time_embeds = gen_time_embed(meta_model.meta_timestamps)
batch_indexes = random.choices(total_indexes, k=args.meta_batch)
@@ -219,11 +219,11 @@ def main(args):
w_containers, loss_meter = online_evaluate(
all_env, meta_model, base_model, criterion, args, logger, True
)
logger.log("In this enviornment, the loss-meter is {:}".format(loss_meter))
logger.log("In this enviornment, the total loss-meter is {:}".format(loss_meter))
save_checkpoint(
{"w_containers": w_containers},
logger.path(None) / "final-ckp.pth",
{"all_w_containers": w_containers},
logger.path(None) / "final-ckp-{:}.pth".format(args.rand_seed),
logger,
)

View File

@@ -154,8 +154,9 @@ class MetaModelV1(super_core.SuperModule):
(self._append_meta_embed["fixed"], meta_embed), dim=0
)
def _obtain_time_embed(self, timestamps):
# timestamps is a batch of sequence of timestamps
def gen_time_embed(self, timestamps):
# timestamps is a batch of timestamps
[B] = timestamps.shape
# batch, seq = timestamps.shape
timestamps = timestamps.view(-1, 1)
meta_timestamps, meta_embeds = self.meta_timestamps, self.super_meta_embed
@@ -179,15 +180,8 @@ class MetaModelV1(super_core.SuperModule):
)
return timestamp_embeds[:, -1, :]
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
if time_embeds is None:
[B] = timestamps.shape
time_embeds = self._obtain_time_embed(timestamps)
else: # use the hyper-net only
time_seq = None
B, _ = time_embeds.shape
if tembed_only:
return time_embeds
def gen_model(self, time_embeds):
B, _ = time_embeds.shape
# create joint embed
num_layer, _ = self._super_layer_embed.shape
# The shape of `joint_embed` is batch * num-layers * input-dim
@@ -206,6 +200,9 @@ class MetaModelV1(super_core.SuperModule):
)
return batch_containers, time_embeds
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
raise NotImplementedError
def forward_candidate(self, input):
raise NotImplementedError

View File

@@ -1,8 +1,9 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
############################################################################
# python exps/GMOA/vis-synthetic.py --env_version v1 #
# python exps/GMOA/vis-synthetic.py --env_version v2 #
# python exps/GeMOSA/vis-synthetic.py --env_version v1 #
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
# python exps/GeMOSA/vis-synthetic.py --env_version v2 #
############################################################################
import os, sys, copy, random
import torch
@@ -181,7 +182,7 @@ def compare_cl(save_dir):
def visualize_env(save_dir, version):
save_dir = Path(str(save_dir))
for substr in ("pdf", "png"):
sub_save_dir = save_dir / substr
sub_save_dir = save_dir / "{:}-{:}".format(substr, version)
sub_save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env = get_synthetic_env(version=version)
@@ -190,6 +191,8 @@ def visualize_env(save_dir, version):
allxs.append(allx)
allys.append(ally)
allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item()))
print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item()))
for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
dpi, width, height = 30, 1800, 1400
figsize = width / float(dpi), height / float(dpi)
@@ -210,14 +213,22 @@ def visualize_env(save_dir, version):
cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
cur_ax.legend(loc=1, fontsize=LegendFontsize)
pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
pdf_save_path = (
save_dir
/ "pdf-{:}".format(version)
/ "v{:}-{:05d}.pdf".format(version, idx)
)
fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
png_save_path = (
save_dir
/ "png-{:}".format(version)
/ "v{:}-{:05d}.png".format(version, idx)
)
fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")
save_dir = save_dir.resolve()
base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
xdir=save_dir / "png", version=version
xdir=save_dir / "png-{:}".format(version), version=version
)
print(base_cmd)
os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))
@@ -367,7 +378,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
visualize_env(os.path.join(args.save_dir, "vis-env"), args.env_version)
# visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
# compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
# compare_cl(os.path.join(args.save_dir, "compare-cl"))