added small 1M paramter model

This commit is contained in:
Zach Teed
2020-08-23 22:40:47 -06:00
parent c86b3dc8f3
commit 01ad964d94
5 changed files with 18 additions and 27 deletions

View File

@@ -44,7 +44,7 @@ SUM_FREQ = 100
VAL_FREQ = 5000
def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
""" Loss function defined over sequence of flow predictions """
n_predictions = len(flow_preds)
@@ -55,7 +55,7 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
valid = (valid >= 0.5) & (mag < max_flow)
for i in range(n_predictions):
i_weight = 0.8**(n_predictions - i - 1)
i_weight = gamma**(n_predictions - i - 1)
i_loss = (flow_preds[i] - flow_gt).abs()
flow_loss += i_weight * (valid[:, None] * i_loss).mean()
@@ -71,16 +71,11 @@ def sequence_loss(flow_preds, flow_gt, valid, max_flow=MAX_FLOW):
return flow_loss, metrics
def show_image(img):
img = img.permute(1,2,0).cpu().numpy()
plt.imshow(img/255.0)
plt.show()
# cv2.imshow('image', img/255.0)
# cv2.waitKey()
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def fetch_optimizer(args, model):
""" Create the optimizer and learning rate scheduler """
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=args.epsilon)
@@ -169,9 +164,6 @@ def train(args):
optimizer.zero_grad()
image1, image2, flow, valid = [x.cuda() for x in data_blob]
# show_image(image1[0])
# show_image(image2[0])
if args.add_noise:
stdv = np.random.uniform(0.0, 5.0)
image1 = (image1 + stdv * torch.randn(*image1.shape).cuda()).clamp(0.0, 255.0)
@@ -179,7 +171,7 @@ def train(args):
flow_predictions = model(image1, image2, iters=args.iters)
loss, metrics = sequence_loss(flow_predictions, flow, valid)
loss, metrics = sequence_loss(flow_predictions, flow, valid, args.gamma)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
@@ -188,7 +180,6 @@ def train(args):
scheduler.step()
scaler.update()
logger.push(metrics)
if total_steps % VAL_FREQ == VAL_FREQ - 1:
@@ -243,6 +234,7 @@ if __name__ == '__main__':
parser.add_argument('--epsilon', type=float, default=1e-8)
parser.add_argument('--clip', type=float, default=1.0)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting')
parser.add_argument('--add_noise', action='store_true')
args = parser.parse_args()