Use black for lib/models
This commit is contained in:
@@ -41,10 +41,14 @@ def main(args):
|
||||
shape_container = model.get_w_container().to_shape_container()
|
||||
hypernet = HyperNet(shape_container, args.hidden_dim, args.task_dim)
|
||||
# task_embed = torch.nn.Parameter(torch.Tensor(env_info["total"], args.task_dim))
|
||||
task_embed = torch.nn.Parameter(torch.Tensor(1, args.task_dim))
|
||||
trunc_normal_(task_embed, std=0.02)
|
||||
total_bar = 10
|
||||
task_embeds = []
|
||||
for i in range(total_bar):
|
||||
task_embeds.append(torch.nn.Parameter(torch.Tensor(1, args.task_dim)))
|
||||
for task_embed in task_embeds:
|
||||
trunc_normal_(task_embed, std=0.02)
|
||||
|
||||
parameters = list(hypernet.parameters()) + [task_embed]
|
||||
parameters = list(hypernet.parameters()) + task_embeds
|
||||
optimizer = torch.optim.Adam(parameters, lr=args.init_lr, amsgrad=True)
|
||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
@@ -56,7 +60,6 @@ def main(args):
|
||||
)
|
||||
|
||||
# total_bar = env_info["total"] - 1
|
||||
total_bar = 1
|
||||
# LFNA meta-training
|
||||
loss_meter = AverageMeter()
|
||||
per_epoch_time, start_time = AverageMeter(), time.time()
|
||||
@@ -74,7 +77,7 @@ def main(args):
|
||||
# for ibatch in range(args.meta_batch):
|
||||
for cur_time in range(total_bar):
|
||||
# cur_time = random.randint(0, total_bar)
|
||||
cur_task_embed = task_embed
|
||||
cur_task_embed = task_embeds[cur_time]
|
||||
cur_container = hypernet(cur_task_embed)
|
||||
cur_x = env_info["{:}-x".format(cur_time)]
|
||||
cur_y = env_info["{:}-y".format(cur_time)]
|
||||
@@ -98,7 +101,7 @@ def main(args):
|
||||
+ "meta-loss: {:.4f} ({:.4f}) :: lr={:.5f}, batch={:}".format(
|
||||
loss_meter.avg,
|
||||
loss_meter.val,
|
||||
min(lr_scheduler.get_lr()),
|
||||
min(lr_scheduler.get_last_lr()),
|
||||
len(losses),
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user