commit f8fab323c4
parent 3df96621ed
Author: Nikita Karaev
Date: 2023-12-27 12:54:02 +00:00

    release cotracker 2.0

38 changed files with 2238 additions and 1910 deletions
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py

@@ -25,22 +25,35 @@ from torch.utils.tensorboard import SummaryWriter
 from pytorch_lightning.lite import LightningLite

 from cotracker.models.evaluation_predictor import EvaluationPredictor
-from cotracker.models.core.cotracker.cotracker import CoTracker
+from cotracker.models.core.cotracker.cotracker import CoTracker2
 from cotracker.utils.visualizer import Visualizer
 from cotracker.datasets.tap_vid_datasets import TapVidDataset
-from cotracker.datasets.badja_dataset import BadjaDataset
-from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
+from cotracker.datasets.dr_dataset import DynamicReplicaDataset
 from cotracker.evaluation.core.evaluator import Evaluator
 from cotracker.datasets import kubric_movif_dataset
 from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
 from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss


+# define the handler function
+# for training on a slurm cluster
+def sig_handler(signum, frame):
+    print("caught signal", signum)
+    print(socket.gethostname(), "USR1 signal caught.")
+    # do other stuff to cleanup here
+    print("requeuing job " + os.environ["SLURM_JOB_ID"])
+    os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
+    sys.exit(-1)
+
+
+def term_handler(signum, frame):
+    print("bypassing sigterm", flush=True)
+
+
 def fetch_optimizer(args, model):
     """Create the optimizer and learning rate scheduler"""
-    optimizer = optim.AdamW(
-        model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
-    )
+    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
     scheduler = optim.lr_scheduler.OneCycleLR(
         optimizer,
         args.lr,
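A note on the optimizer pairing above: AdamW and OneCycleLR advance together, with the scheduler stepped once per optimizer update rather than per epoch. Below is a minimal standalone sketch (illustrative only, not part of this commit; the model, step count, and scheduler options are stand-ins for the argparse defaults):

import torch
from torch import optim

model = torch.nn.Linear(8, 2)  # stand-in for CoTracker2
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-5, eps=1e-8)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=5e-4, total_steps=100, pct_start=0.05, cycle_momentum=False
)

for step in range(100):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # OneCycleLR is stepped per batch, not per epoch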
@@ -53,69 +66,61 @@ def fetch_optimizer(args, model):
     return optimizer, scheduler


-def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
-    rgbs = batch.video
+def forward_batch(batch, model, args):
+    video = batch.video
     trajs_g = batch.trajectory
     vis_g = batch.visibility
     valids = batch.valid
-    B, T, C, H, W = rgbs.shape
+    B, T, C, H, W = video.shape
     assert C == 3
     B, T, N, D = trajs_g.shape
-    device = rgbs.device
+    device = video.device

     __, first_positive_inds = torch.max(vis_g, dim=1)
     # We want to make sure that during training the model sees visible points
     # that it does not need to track just yet: they are visible but queried from a later frame
     N_rand = N // 4
     # inds of visible points in the 1st frame
-    nonzero_inds = [torch.nonzero(vis_g[0, :, i]) for i in range(N)]
-    rand_vis_inds = torch.cat(
-        [
-            nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
-            for nonzero_row in nonzero_inds
-        ],
-        dim=1,
-    )
-    first_positive_inds = torch.cat(
-        [rand_vis_inds[:, :N_rand], first_positive_inds[:, N_rand:]], dim=1
-    )
+    nonzero_inds = [[torch.nonzero(vis_g[b, :, i]) for i in range(N)] for b in range(B)]
+    for b in range(B):
+        rand_vis_inds = torch.cat(
+            [
+                nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
+                for nonzero_row in nonzero_inds[b]
+            ],
+            dim=1,
+        )
+        first_positive_inds[b] = torch.cat(
+            [rand_vis_inds[:, :N_rand], first_positive_inds[b : b + 1, N_rand:]], dim=1
+        )
+
     ind_array_ = torch.arange(T, device=device)
     ind_array_ = ind_array_[None, :, None].repeat(B, 1, N)
     assert torch.allclose(
         vis_g[ind_array_ == first_positive_inds[:, None, :]],
-        torch.ones_like(vis_g),
-    )
-    assert torch.allclose(
-        vis_g[ind_array_ == rand_vis_inds[:, None, :]], torch.ones_like(vis_g)
-    )
-    gather = torch.gather(
-        trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
+        torch.ones(1, device=device),
     )
+    gather = torch.gather(trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, D))
     xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)

-    queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2)
+    queries = torch.cat([first_positive_inds[:, :, None], xys[:, :, :2]], dim=2)

-    predictions, __, visibility, train_data = model(
-        rgbs=rgbs, queries=queries, iters=args.train_iters, is_train=True
+    predictions, visibility, train_data = model(
+        video=video, queries=queries, iters=args.train_iters, is_train=True
     )
-    vis_predictions, coord_predictions, wind_inds, sort_inds = train_data
-
-    trajs_g = trajs_g[:, :, sort_inds]
-    vis_g = vis_g[:, :, sort_inds]
-    valids = valids[:, :, sort_inds]
+    coord_predictions, vis_predictions, valid_mask = train_data

     vis_gts = []
     traj_gts = []
     valids_gts = []
-    for i, wind_idx in enumerate(wind_inds):
-        ind = i * (args.sliding_window_len // 2)
-
-        vis_gts.append(vis_g[:, ind : ind + args.sliding_window_len, :wind_idx])
-        traj_gts.append(trajs_g[:, ind : ind + args.sliding_window_len, :wind_idx])
-        valids_gts.append(valids[:, ind : ind + args.sliding_window_len, :wind_idx])
+
+    S = args.sliding_window_len
+    for ind in range(0, args.sequence_len - S // 2, S // 2):
+        vis_gts.append(vis_g[:, ind : ind + S])
+        traj_gts.append(trajs_g[:, ind : ind + S])
+        valids_gts.append(valids[:, ind : ind + S] * valid_mask[:, ind : ind + S])

     seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
     vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)
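The new ground-truth chunking above replaces the old per-window bookkeeping (wind_inds, sort_inds) with a fixed tiling: half-overlapping windows of length S with stride S // 2. A small illustrative sketch of the resulting windows (hypothetical lengths, not from this commit):

# How the loop tiles a sequence of length sequence_len into
# half-overlapping windows of length S (stride S // 2).
sequence_len, S = 24, 8
windows = [(ind, min(ind + S, sequence_len)) for ind in range(0, sequence_len - S // 2, S // 2)]
print(windows)  # [(0, 8), (4, 12), (8, 16), (12, 20), (16, 24)]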
@@ -131,9 +136,17 @@ def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
 def run_test_eval(evaluator, model, dataloaders, writer, step):
     model.eval()
     for ds_name, dataloader in dataloaders:
+        visualize_every = 1
+        grid_size = 5
+        if ds_name == "dynamic_replica":
+            visualize_every = 8
+            grid_size = 0
+        elif "tapvid" in ds_name:
+            visualize_every = 5
+
         predictor = EvaluationPredictor(
             model.module.module,
-            grid_size=6,
+            grid_size=grid_size,
             local_grid_size=0,
             single_point=False,
             n_iters=6,
@@ -148,37 +161,23 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
             train_mode=True,
             writer=writer,
             step=step,
+            visualize_every=visualize_every,
         )

-        if ds_name == "badja" or ds_name == "fastcapture" or ("kubric" in ds_name):
-            metrics = {
-                **{
-                    f"{ds_name}_avg": np.mean(
-                        [v for k, v in metrics.items() if "accuracy" not in k]
-                    )
-                },
-                **{
-                    f"{ds_name}_avg_accuracy": np.mean(
-                        [v for k, v in metrics.items() if "accuracy" in k]
-                    )
-                },
-            }
-            print("avg", np.mean([v for v in metrics.values()]))
-
         if ds_name == "dynamic_replica" or ds_name == "kubric":
             metrics = {f"{ds_name}_avg_{k}": v for k, v in metrics["avg"].items()}

         if "tapvid" in ds_name:
             metrics = {
-                f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"] * 100,
-                f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"]
-                * 100,
-                f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"] * 100,
+                f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"],
+                f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"],
+                f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"],
             }

-        writer.add_scalars(f"Eval", metrics, step)
+        writer.add_scalars(f"Eval_{ds_name}", metrics, step)


 class Logger:
     SUM_FREQ = 100

     def __init__(self, model, scheduler):
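The per-dataset tag change above (Eval_{ds_name} instead of one shared Eval tag) keeps each benchmark's curves grouped in TensorBoard. A minimal sketch of how such a flattened metrics dict lands in the writer (the values and log directory are made up for illustration):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/example")  # hypothetical directory
metrics = {
    "tapvid_davis_avg_OA": 0.87,  # fabricated values, for illustration only
    "tapvid_davis_avg_delta": 0.75,
    "tapvid_davis_avg_Jaccard": 0.61,
}
writer.add_scalars("Eval_tapvid_davis", metrics, global_step=0)
writer.close()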
@@ -190,24 +189,19 @@ class Logger:
     def _print_training_status(self):
         metrics_data = [
-            self.running_loss[k] / Logger.SUM_FREQ
-            for k in sorted(self.running_loss.keys())
+            self.running_loss[k] / Logger.SUM_FREQ for k in sorted(self.running_loss.keys())
         ]
         training_str = "[{:6d}] ".format(self.total_steps + 1)
         metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

         # print the training status
-        logging.info(
-            f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
-        )
+        logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")

         if self.writer is None:
             self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))

         for k in self.running_loss:
-            self.writer.add_scalar(
-                k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
-            )
+            self.writer.add_scalar(k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps)
             self.running_loss[k] = 0.0

     def push(self, metrics, task):
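For context, the Logger class above follows a common running-average pattern: accumulate per-key sums and flush the mean every SUM_FREQ pushes. A stripped-down standalone sketch of that pattern (hypothetical code, not the class itself):

SUM_FREQ = 100
running_loss = {}

def push(metrics, step):
    for k, v in metrics.items():
        running_loss[k] = running_loss.get(k, 0.0) + v
    if (step + 1) % SUM_FREQ == 0:
        means = {k: s / SUM_FREQ for k, s in sorted(running_loss.items())}
        print("[{:6d}]".format(step + 1), " ".join(f"{m:10.4f}" for m in means.values()))
        running_loss.clear()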
@@ -249,79 +243,56 @@ class Lite(LightningLite):
         seed_everything(0)

         def seed_worker(worker_id):
-            worker_seed = torch.initial_seed() % 2 ** 32
+            worker_seed = torch.initial_seed() % 2**32
             np.random.seed(worker_seed)
             random.seed(worker_seed)

         g = torch.Generator()
         g.manual_seed(0)

-        eval_dataloaders = []
-        if "badja" in args.eval_datasets:
-            eval_dataset = BadjaDataset(
-                data_root=os.path.join(args.dataset_root, "BADJA"),
-                max_seq_len=args.eval_max_seq_len,
-                dataset_resolution=args.crop_size,
-            )
-            eval_dataloader_badja = torch.utils.data.DataLoader(
-                eval_dataset,
-                batch_size=1,
-                shuffle=False,
-                num_workers=8,
-                collate_fn=collate_fn,
-            )
-            eval_dataloaders.append(("badja", eval_dataloader_badja))
-
-        if "fastcapture" in args.eval_datasets:
-            eval_dataset = FastCaptureDataset(
-                data_root=os.path.join(args.dataset_root, "fastcapture"),
-                max_seq_len=min(100, args.eval_max_seq_len),
-                max_num_points=40,
-                dataset_resolution=args.crop_size,
-            )
-            eval_dataloader_fastcapture = torch.utils.data.DataLoader(
-                eval_dataset,
-                batch_size=1,
-                shuffle=False,
-                num_workers=1,
-                collate_fn=collate_fn,
-            )
-            eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
-
-        if "tapvid_davis_first" in args.eval_datasets:
-            data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
-            eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
-            eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
-                eval_dataset,
-                batch_size=1,
-                shuffle=False,
-                num_workers=1,
-                collate_fn=collate_fn,
-            )
-            eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
-
-        evaluator = Evaluator(args.ckpt_path)
-
-        visualizer = Visualizer(
-            save_dir=args.ckpt_path,
-            pad_value=80,
-            fps=1,
-            show_first_frame=0,
-            tracks_leave_trace=0,
-        )
-
-        loss_fn = None
+        if self.global_rank == 0:
+            eval_dataloaders = []
+            if "dynamic_replica" in args.eval_datasets:
+                eval_dataset = DynamicReplicaDataset(
+                    sample_len=60, only_first_n_samples=1, rgbd_input=False
+                )
+                eval_dataloader_dr = torch.utils.data.DataLoader(
+                    eval_dataset,
+                    batch_size=1,
+                    shuffle=False,
+                    num_workers=1,
+                    collate_fn=collate_fn,
+                )
+                eval_dataloaders.append(("dynamic_replica", eval_dataloader_dr))
+
+            if "tapvid_davis_first" in args.eval_datasets:
+                data_root = os.path.join(args.dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
+                eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
+                eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
+                    eval_dataset,
+                    batch_size=1,
+                    shuffle=False,
+                    num_workers=1,
+                    collate_fn=collate_fn,
+                )
+                eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
+
+            evaluator = Evaluator(args.ckpt_path)
+
+            visualizer = Visualizer(
+                save_dir=args.ckpt_path,
+                pad_value=80,
+                fps=1,
+                show_first_frame=0,
+                tracks_leave_trace=0,
+            )

         if args.model_name == "cotracker":
-            model = CoTracker(
+            model = CoTracker2(
                 stride=args.model_stride,
-                S=args.sliding_window_len,
+                window_len=args.sliding_window_len,
                 add_space_attn=not args.remove_space_attn,
-                num_heads=args.updateformer_num_heads,
-                hidden_size=args.updateformer_hidden_size,
-                space_depth=args.updateformer_space_depth,
-                time_depth=args.updateformer_time_depth,
+                num_virtual_tracks=args.num_virtual_tracks,
+                model_resolution=args.crop_size,
             )
         else:
             raise ValueError(f"Model {args.model_name} doesn't exist")
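The seed_worker/generator pair above is PyTorch's standard reproducibility recipe for multi-worker data loading. A self-contained sketch of how it is wired into a DataLoader (toy dataset, not from this commit):

import random
import numpy as np
import torch

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.arange(10.0)),
    batch_size=2,
    shuffle=True,
    num_workers=2,
    worker_init_fn=seed_worker,  # reseeds numpy/random inside each worker
    generator=g,  # fixes the shuffle order across runs
)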
@@ -332,7 +303,7 @@ class Lite(LightningLite):
         model.cuda()

         train_dataset = kubric_movif_dataset.KubricMovifDataset(
-            data_root=os.path.join(args.dataset_root, "kubric_movi_f"),
+            data_root=os.path.join(args.dataset_root, "kubric", "kubric_movi_f_tracks"),
             crop_size=args.crop_size,
             seq_len=args.sequence_len,
             traj_per_sample=args.traj_per_sample,
@@ -357,7 +328,8 @@ class Lite(LightningLite):
         optimizer, scheduler = fetch_optimizer(args, model)

         total_steps = 0
-        logger = Logger(model, scheduler)
+        if self.global_rank == 0:
+            logger = Logger(model, scheduler)

         folder_ckpts = [
             f
@@ -383,9 +355,7 @@ class Lite(LightningLite):
             logging.info(f"Load total_steps {total_steps}")

         elif args.restore_ckpt is not None:
-            assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
-                ".pt"
-            )
+            assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(".pt")
             logging.info("Loading checkpoint...")

             strict = True
@@ -394,9 +364,7 @@ class Lite(LightningLite):
                 state_dict = state_dict["model"]

             if list(state_dict.keys())[0].startswith("module."):
-                state_dict = {
-                    k.replace("module.", ""): v for k, v in state_dict.items()
-                }
+                state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
             model.load_state_dict(state_dict, strict=strict)

             logging.info(f"Done loading checkpoint")
@@ -424,33 +392,22 @@ class Lite(LightningLite):
             assert model.training

-            output = forward_batch(
-                batch,
-                model,
-                args,
-                loss_fn=loss_fn,
-                writer=logger.writer,
-                step=total_steps,
-            )
+            output = forward_batch(batch, model, args)

             loss = 0
             for k, v in output.items():
                 if "loss" in v:
                     loss += v["loss"]
-                    logger.writer.add_scalar(
-                        f"live_{k}_loss", v["loss"].item(), total_steps
-                    )
-                if "metrics" in v:
-                    logger.push(v["metrics"], k)

             if self.global_rank == 0:
-                if total_steps % save_freq == save_freq - 1:
-                    if args.model_name == "motion_diffuser":
-                        pred_coords = model.module.module.forward_batch_test(
-                            batch, interp_shape=args.crop_size
-                        )
-                        output["flow"] = {"predictions": pred_coords[0].detach()}
+                for k, v in output.items():
+                    if "loss" in v:
+                        logger.writer.add_scalar(
+                            f"live_{k}_loss", v["loss"].item(), total_steps
+                        )
+                    if "metrics" in v:
+                        logger.push(v["metrics"], k)
+                if total_steps % save_freq == save_freq - 1:
                     visualizer.visualize(
                         video=batch.video.clone(),
                         tracks=batch.trajectory.clone(),
@@ -468,9 +425,7 @@ class Lite(LightningLite):
                     )

                 if len(output) > 1:
-                    logger.writer.add_scalar(
-                        f"live_total_loss", loss.item(), total_steps
-                    )
+                    logger.writer.add_scalar(f"live_total_loss", loss.item(), total_steps)
                 logger.writer.add_scalar(
                     f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
                 )
@@ -492,9 +447,7 @@ class Lite(LightningLite):
                     total_steps == 1 and args.validate_at_start
                 ):
                     if (epoch + 1) % args.save_every_n_epoch == 0:
-                        ckpt_iter = "0" * (6 - len(str(total_steps))) + str(
-                            total_steps
-                        )
+                        ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps)
                         save_path = Path(
                             f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
                         )
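The zero-padding idiom kept above is equivalent to Python's built-in str.zfill; a quick illustrative check (not part of the commit):

total_steps = 472
ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps)
assert ckpt_iter == str(total_steps).zfill(6) == "000472"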
@@ -526,16 +479,18 @@ class Lite(LightningLite):
             if total_steps > args.num_steps:
                 should_keep_training = False
                 break

-        print("FINISHED TRAINING")
-
-        PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
-        torch.save(model.module.module.state_dict(), PATH)
-        run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
-        logger.close()
+        if self.global_rank == 0:
+            print("FINISHED TRAINING")
+
+            PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
+            torch.save(model.module.module.state_dict(), PATH)
+            run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
+            logger.close()


 if __name__ == "__main__":
+    signal.signal(signal.SIGUSR1, sig_handler)
+    signal.signal(signal.SIGTERM, term_handler)
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name", default="cotracker", help="model name")
     parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
@@ -543,17 +498,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--batch_size", type=int, default=4, help="batch size used during training."
     )
-    parser.add_argument(
-        "--num_workers", type=int, default=6, help="number of dataloader workers"
-    )
-    parser.add_argument(
-        "--mixed_precision", action="store_true", help="use mixed precision"
-    )
+    parser.add_argument("--num_nodes", type=int, default=1)
+    parser.add_argument("--num_workers", type=int, default=10, help="number of dataloader workers")
+    parser.add_argument("--mixed_precision", action="store_true", help="use mixed precision")
     parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
-    parser.add_argument(
-        "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
-    )
+    parser.add_argument("--wdecay", type=float, default=0.00001, help="Weight decay in optimizer.")
     parser.add_argument(
         "--num_steps", type=int, default=200000, help="length of training schedule."
     )
@@ -596,13 +546,11 @@ if __name__ == "__main__":
         default=4,
         help="number of updates to the disparity field in each forward pass.",
     )
-    parser.add_argument(
-        "--sequence_len", type=int, default=8, help="train sequence length"
-    )
+    parser.add_argument("--sequence_len", type=int, default=8, help="train sequence length")
     parser.add_argument(
         "--eval_datasets",
         nargs="+",
-        default=["things", "badja"],
+        default=["tapvid_davis_first"],
         help="what datasets to use for evaluation",
     )
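On the eval_datasets flag above: nargs="+" lets the flag take one or more names on the command line. A quick illustration (hypothetical invocation):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--eval_datasets", nargs="+", default=["tapvid_davis_first"])
args = parser.parse_args(["--eval_datasets", "tapvid_davis_first", "dynamic_replica"])
print(args.eval_datasets)  # ['tapvid_davis_first', 'dynamic_replica']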
@@ -611,6 +559,12 @@ if __name__ == "__main__":
         action="store_true",
         help="remove space attention from CoTracker",
     )
+    parser.add_argument(
+        "--num_virtual_tracks",
+        type=int,
+        default=None,
+        help="number of virtual tracks used by the CoTracker model",
+    )
     parser.add_argument(
         "--dont_use_augs",
         action="store_true",
@@ -627,30 +581,6 @@ if __name__ == "__main__":
         default=8,
         help="length of the CoTracker sliding window",
     )
-    parser.add_argument(
-        "--updateformer_hidden_size",
-        type=int,
-        default=384,
-        help="hidden dimension of the CoTracker transformer model",
-    )
-    parser.add_argument(
-        "--updateformer_num_heads",
-        type=int,
-        default=8,
-        help="number of heads of the CoTracker transformer model",
-    )
-    parser.add_argument(
-        "--updateformer_space_depth",
-        type=int,
-        default=12,
-        help="number of group attention layers in the CoTracker transformer model",
-    )
-    parser.add_argument(
-        "--updateformer_time_depth",
-        type=int,
-        default=12,
-        help="number of time attention layers in the CoTracker transformer model",
-    )
     parser.add_argument(
         "--model_stride",
         type=int,
@@ -680,9 +610,9 @@ if __name__ == "__main__":
     from pytorch_lightning.strategies import DDPStrategy

     Lite(
-        strategy=DDPStrategy(find_unused_parameters=True),
+        strategy=DDPStrategy(find_unused_parameters=False),
         devices="auto",
         accelerator="gpu",
         precision=32,
-        # num_nodes=4,
+        num_nodes=args.num_nodes,
     ).run(args)
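On the strategy change above: find_unused_parameters=False skips DDP's per-iteration scan for parameters that received no gradient, which is faster but only valid when every registered parameter participates in each backward pass, presumably the case for CoTracker2. A skeletal sketch of the launch shape, mirroring the bottom of this file (assumes a GPU machine with pytorch_lightning installed; the empty run body is a placeholder):

from pytorch_lightning.lite import LightningLite
from pytorch_lightning.strategies import DDPStrategy

class MyLite(LightningLite):
    def run(self, args):
        pass  # set up model/optimizer with self.setup(...) and train here

MyLite(
    strategy=DDPStrategy(find_unused_parameters=False),
    devices="auto",
    accelerator="gpu",
    precision=32,
    num_nodes=1,
).run(None)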