Update metrics
This commit is contained in:
@@ -69,10 +69,13 @@ def main(args):
|
||||
weight_decay=args.weight_decay,
|
||||
)
|
||||
objective = xmisc.nested_call_by_yaml(args.loss_config)
|
||||
metric = xmisc.nested_call_by_yaml(args.metric_config)
|
||||
|
||||
logger.log("The optimizer is:\n{:}".format(optimizer))
|
||||
logger.log("The objective is {:}".format(objective))
|
||||
logger.log("The iters_per_epoch={:}".format(iters_per_epoch))
|
||||
logger.log("The metric is {:}".format(metric))
|
||||
logger.log("The iters_per_epoch = {:}, estimated epochs = {:}".format(
|
||||
iters_per_epoch, args.steps // iters_per_epoch))
|
||||
|
||||
model, objective = torch.nn.DataParallel(model).cuda(), objective.cuda()
|
||||
scheduler = xmisc.LRMultiplier(
|
||||
@@ -99,6 +102,7 @@ def main(args):
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
if xiter % iters_per_epoch == 0:
|
||||
logger.log("TRAIN [{:}] loss = {:.6f}".format(iter_str, loss.item()))
|
||||
|
||||
@@ -123,6 +127,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--model_config", type=str, help="The path to the model config")
|
||||
parser.add_argument("--optim_config", type=str, help="The optimizer config file.")
|
||||
parser.add_argument("--loss_config", type=str, help="The loss config file.")
|
||||
parser.add_argument("--metric_config", type=str, help="The metric config file.")
|
||||
parser.add_argument(
|
||||
"--train_data_config", type=str, help="The training dataset config path."
|
||||
)
|
||||
|
Reference in New Issue
Block a user