Fix 1-element in norm bug

This commit is contained in:
D-X-Y
2021-05-12 19:09:17 +08:00
parent 80ccc49d92
commit 06f4a1f1cf
3 changed files with 15 additions and 8 deletions

View File

@@ -58,6 +58,7 @@ def main(args):
# build model
model = get_model(**model_kwargs)
print(model)
model.analyze_weights()
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss()
@@ -168,7 +169,7 @@ if __name__ == "__main__":
parser.add_argument(
"--epochs",
type=int,
default=1000,
default=300,
help="The total number of epochs.",
)
parser.add_argument(