Update xmisc.scheduler/sampler

This commit is contained in:
D-X-Y
2021-06-11 11:46:18 +08:00
parent 9bf0fa5f04
commit 48163c792c
17 changed files with 807 additions and 201 deletions

View File

@@ -46,8 +46,7 @@ def main(args):
train_loader = torch.utils.data.DataLoader(
train_data,
batch_size=args.batch_size,
shuffle=True,
batch_sampler=xmisc.BatchSampler(train_data, args.batch_size, args.steps),
num_workers=args.workers,
pin_memory=True,
)
@@ -57,6 +56,7 @@ def main(args):
shuffle=False,
num_workers=args.workers,
pin_memory=True,
drop_last=False,
)
logger.log("The training loader: {:}".format(train_loader))
@@ -73,6 +73,9 @@ def main(args):
logger.log("The loss is {:}".format(loss))
model, loss = torch.nn.DataParallel(model).cuda(), loss.cuda()
scheduler = xmisc.LRMultiplier(
optimizer, xmisc.get_scheduler(args.scheduler, args.lr), args.steps
)
import pdb
@@ -241,10 +244,11 @@ if __name__ == "__main__":
"--valid_data_config", type=str, help="The validation dataset config path."
)
parser.add_argument("--data_path", type=str, help="The path to the dataset.")
parser.add_argument("--algorithm", type=str, help="The algorithm.")
# Optimization options
parser.add_argument("--lr", type=float, help="The learning rate")
parser.add_argument("--weight_decay", type=float, help="The weight decay")
parser.add_argument("--scheduler", type=str, help="The scheduler indicator.")
parser.add_argument("--steps", type=int, help="The total number of steps.")
parser.add_argument("--batch_size", type=int, default=2, help="The batch size.")
parser.add_argument("--workers", type=int, default=4, help="The number of workers")
# Random Seed