Added small 1M-parameter model
This commit is contained in:
18
train.py
18
train.py
@@ -44,7 +44,7 @@ SUM_FREQ = 100
|
||||
VAL_FREQ = 5000
|
||||
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
|
||||
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
|
||||
""" Loss function defined over sequence of flow predictions """
|
||||
|
||||
n_predictions = len(flow_preds)
|
||||
@@ -55,7 +55,7 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
|
||||
valid = (valid >= 0.5) & (mag < max_flow)
|
||||
|
||||
for i in range(n_predictions):
|
||||
i_weight = 0.8**(n_predictions - i - 1)
|
||||
i_weight = gamma**(n_predictions - i - 1)
|
||||
i_loss = (flow_preds[i] - flow_gt).abs()
|
||||
flow_loss += i_weight * (valid[:, None] * i_loss).mean()
|
||||
|
||||
@@ -71,16 +71,11 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
|
||||
|
||||
return flow_loss, metrics
|
||||
|
||||
def show_image(img):
|
||||
img = img.permute(1,2,0).cpu().numpy()
|
||||
plt.imshow(img/255.0)
|
||||
plt.show()
|
||||
# cv2.imshow('image', img/255.0)
|
||||
# cv2.waitKey()
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
def fetch_optimizer(args, model):
|
||||
""" Create the optimizer and learning rate scheduler """
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
|
||||
@@ -169,9 +164,6 @@ def train(args):
|
||||
optimizer.zero_grad()
|
||||
image1, image2, flow, valid = [x.cuda() for x in data_blob]
|
||||
|
||||
# show_image(image1[0])
|
||||
# show_image(image2[0])
|
||||
|
||||
if args.add_noise:
|
||||
stdv = np.random.uniform(0.0, 5.0)
|
||||
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
|
||||
@@ -179,7 +171,7 @@ def train(args):
|
||||
|
||||
flow_predictions = model(image1, image2, iters=args.iters)
|
||||
|
||||
loss, metrics = sequence_loss(flow_predictions, flow, valid)
|
||||
loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
|
||||
@@ -188,7 +180,6 @@ def train(args):
|
||||
scheduler.step()
|
||||
scaler.update()
|
||||
|
||||
|
||||
logger.push(metrics)
|
||||
|
||||
if total_steps % VAL_FREQ == VAL_FREQ - 1:
|
||||
@@ -243,6 +234,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--epsilon', type=float, default=1e-8)
|
||||
parser.add_argument('--clip', type=float, default=1.0)
|
||||
parser.add_argument('--dropout', type=float, default=0.0)
|
||||
parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
|
||||
parser.add_argument('--add_noise', action='store_true')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user