added training code

This commit is contained in:
Zach Teed
2020-07-30 21:25:36 -06:00
parent dc370f877b
commit a1d8344039
5 changed files with 25 additions and 13 deletions

View File

@@ -39,7 +39,7 @@ except:
# exclude extremly large displacements
MAX_FLOW = 500
MAX_FLOW = 400
SUM_FREQ = 100
VAL_FREQ = 5000
@@ -181,13 +181,14 @@ def train(args):
loss, metrics = sequence_loss(flow_predictions, flow, valid)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
scaler.step(optimizer)
scheduler.step()
scaler.update()
logger.push(metrics)
if total_steps % VAL_FREQ == VAL_FREQ - 1: