Use black for lib/models

This commit is contained in:
D-X-Y
2021-05-12 16:28:05 +08:00
parent d51e5fdc7f
commit f1c47af5fa
42 changed files with 7552 additions and 4688 deletions

View File

@@ -41,10 +41,14 @@ def main(args):
shape_container = model.get_w_container().to_shape_container()
hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim)
# task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim))
task_embed = torch.nn.Parameter(torch.Tensor(1, args.task_dim))
trunc_normal_(task_embed, std=0.02)
total_bar = 10
task_embeds = []
for i in range(total_bar):
task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim)))
for task_embed in task_embeds:
trunc_normal_(task_embed, std=0.02)
parameters = list(hypernet.parameters()) + [task_embed]
parameters = list(hypernet.parameters()) + task_embeds
optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
@@ -56,7 +60,6 @@ def main(args):
)
# total_bar = env_info["total"] - 1
total_bar = 1
# LFNA meta-training
loss_meter = AverageMeter()
per_epoch_time, start_time = AverageMeter(), time.time()
@@ -74,7 +77,7 @@ def main(args):
# for ibatch in range(args.meta_batch):
for cur_time in range(total_bar):
# cur_time = random.randint(0, total_bar)
cur_task_embed = task_embed
cur_task_embed = task_embeds[cur_time]
cur_container = hypernet(cur_task_embed)
cur_x = env_info["{:}-x".format(cur_time)]
cur_y = env_info["{:}-y".format(cur_time)]
@@ -98,7 +101,7 @@ def main(args):
+ "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format(
loss_meter.avg,
loss_meter.val,
min(lr_scheduler.get_lr()),
min(lr_scheduler.get_last_lr()),
len(losses),
)
)