added training code
This commit is contained in:
7
train.py
7
train.py
@@ -39,7 +39,7 @@ except:
|
||||
|
||||
|
||||
# exclude extremly large displacements
|
||||
MAX_FLOW = 500
|
||||
MAX_FLOW = 400
|
||||
SUM_FREQ = 100
|
||||
VAL_FREQ = 5000
|
||||
|
||||
@@ -181,13 +181,14 @@ def train(args):
|
||||
|
||||
loss, metrics = sequence_loss(flow_predictions, flow, valid)
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
scaler.unscale_(optimizer)
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
|
||||
|
||||
scaler.step(optimizer)
|
||||
scheduler.step()
|
||||
scaler.update()
|
||||
|
||||
|
||||
logger.push(metrics)
|
||||
|
||||
if total_steps % VAL_FREQ == VAL_FREQ - 1:
|
||||
|
||||
Reference in New Issue
Block a user