diff --git a/train.py b/train.py index ebb7d08..57b9f87 100644 --- a/train.py +++ b/train.py @@ -36,21 +36,6 @@ from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_ from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss -# define the handler function -# for training on a slurm cluster -def sig_handler(signum, frame): - print("caught signal", signum) - print(socket.gethostname(), "USR1 signal caught.") - # do other stuff to cleanup here - print("requeuing job " + os.environ["SLURM_JOB_ID"]) - os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"]) - sys.exit(-1) - - -def term_handler(signum, frame): - print("bypassing sigterm", flush=True) - - def fetch_optimizer(args, model): """Create the optimizer and learning rate scheduler""" optimizer = optim.AdamW( @@ -302,9 +287,7 @@ class Lite(LightningLite): eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture)) if "tapvid_davis_first" in args.eval_datasets: - data_root = os.path.join( - args.dataset_root, "/tapvid_davis/tapvid_davis.pkl" - ) + data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl") eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root) eval_dataloader_tapvid_davis = torch.utils.data.DataLoader( eval_dataset, @@ -551,17 +534,15 @@ class Lite(LightningLite): if __name__ == "__main__": - signal.signal(signal.SIGUSR1, sig_handler) - signal.signal(signal.SIGTERM, term_handler) parser = argparse.ArgumentParser() parser.add_argument("--model_name", default="cotracker", help="model name") - parser.add_argument("--restore_ckpt", help="restore checkpoint") - parser.add_argument("--ckpt_path", help="restore checkpoint") + parser.add_argument("--restore_ckpt", help="path to restore a checkpoint") + parser.add_argument("--ckpt_path", help="path to save checkpoints") parser.add_argument( "--batch_size", type=int, default=4, help="batch size used during training." ) parser.add_argument( - "--num_workers", type=int, default=6, help="left right consistency loss" + "--num_workers", type=int, default=6, help="number of dataloader workers" ) parser.add_argument( @@ -578,20 +559,34 @@ if __name__ == "__main__": "--evaluate_every_n_epoch", type=int, default=1, - help="number of flow-field updates during validation forward pass", + help="evaluate during training after every n epochs, after every epoch by default", ) parser.add_argument( "--save_every_n_epoch", type=int, default=1, - help="number of flow-field updates during validation forward pass", + help="save checkpoints during training after every n epochs, after every epoch by default", ) parser.add_argument( - "--validate_at_start", action="store_true", help="use mixed precision" + "--validate_at_start", + action="store_true", + help="whether to run evaluation before training starts", + ) + parser.add_argument( + "--save_freq", + type=int, + default=100, + help="frequency of trajectory visualization during training", + ) + parser.add_argument( + "--traj_per_sample", + type=int, + default=768, + help="the number of trajectories to sample for training", + ) + parser.add_argument( + "--dataset_root", type=str, help="path lo all the datasets (train and eval)" ) - parser.add_argument("--save_freq", type=int, default=100, help="save_freq") - parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq") - parser.add_argument("--dataset_root", type=str, help="path lo all the datasets") parser.add_argument( "--train_iters", @@ -605,49 +600,75 @@ if __name__ == "__main__": parser.add_argument( "--eval_datasets", nargs="+", - default=["things", "badja", "fastcapture"], - help="eval datasets.", + default=["things", "badja"], + help="what datasets to use for evaluation", ) parser.add_argument( - "--remove_space_attn", action="store_true", help="use mixed precision" + "--remove_space_attn", + action="store_true", + help="remove space attention from CoTracker", ) parser.add_argument( - "--dont_use_augs", action="store_true", help="use mixed precision" + "--dont_use_augs", + action="store_true", + help="don't apply augmentations during training", ) parser.add_argument( - "--sample_vis_1st_frame", action="store_true", help="use mixed precision" + "--sample_vis_1st_frame", + action="store_true", + help="only sample trajectories with points visible on the first frame", ) parser.add_argument( - "--sliding_window_len", type=int, default=8, help="use mixed precision" + "--sliding_window_len", + type=int, + default=8, + help="length of the CoTracker sliding window", ) parser.add_argument( - "--updateformer_hidden_size", type=int, default=384, help="use mixed precision" + "--updateformer_hidden_size", + type=int, + default=384, + help="hidden dimension of the CoTracker transformer model", ) parser.add_argument( - "--updateformer_num_heads", type=int, default=8, help="use mixed precision" + "--updateformer_num_heads", + type=int, + default=8, + help="number of heads of the CoTracker transformer model", ) parser.add_argument( - "--updateformer_space_depth", type=int, default=12, help="use mixed precision" + "--updateformer_space_depth", + type=int, + default=12, + help="number of group attention layers in the CoTracker transformer model", ) parser.add_argument( - "--updateformer_time_depth", type=int, default=12, help="use mixed precision" + "--updateformer_time_depth", + type=int, + default=12, + help="number of time attention layers in the CoTracker transformer model", ) parser.add_argument( - "--model_stride", type=int, default=8, help="use mixed precision" + "--model_stride", + type=int, + default=8, + help="stride of the CoTracker feature network", ) parser.add_argument( "--crop_size", type=int, nargs="+", default=[384, 512], - help="use mixed precision", + help="crop videos to this resolution during training", ) parser.add_argument( - "--eval_max_seq_len", type=int, default=1000, help="use mixed precision" + "--eval_max_seq_len", + type=int, + default=1000, + help="maximum length of evaluation videos", ) args = parser.parse_args() - logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s", @@ -661,5 +682,5 @@ if __name__ == "__main__": devices="auto", accelerator="gpu", precision=32, - num_nodes=4, + # num_nodes=4, ).run(args)