Update xmisc.scheduler/sampler
@@ -46,8 +46,7 @@ def main(args):
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        batch_sampler=xmisc.BatchSampler(train_data, args.batch_size, args.steps),
        num_workers=args.workers,
        pin_memory=True,
    )
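
The first hunk replaces the batch_size/shuffle keywords with xmisc.BatchSampler. In PyTorch, batch_sampler is mutually exclusive with batch_size, shuffle, sampler, and drop_last, so those arguments have to go once a custom batch sampler takes over. The implementation of xmisc.BatchSampler is not part of this diff; a minimal sketch with the same (dataset, batch_size, steps) signature, assuming it simply yields a fixed number of randomly drawn index batches, might look like:

import random
from typing import Iterator, List

import torch


class StepBatchSampler:
    """Hypothetical stand-in for xmisc.BatchSampler: yields `steps` batches
    of randomly drawn indices, so the loader length is step-based rather
    than epoch-based (an assumption, not the actual xmisc implementation)."""

    def __init__(self, dataset, batch_size: int, steps: int):
        self.indices = list(range(len(dataset)))
        self.batch_size = batch_size
        self.steps = steps

    def __iter__(self) -> Iterator[List[int]]:
        for _ in range(self.steps):
            yield random.sample(self.indices, self.batch_size)

    def __len__(self) -> int:
        return self.steps


# Usage mirrors the diff: batch_size and shuffle are no longer passed.
train_data = torch.utils.data.TensorDataset(torch.randn(100, 3))
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_sampler=StepBatchSampler(train_data, batch_size=4, steps=10),
    num_workers=0,
    pin_memory=True,
)
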
@@ -57,6 +56,7 @@ def main(args):
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False,
    )

    logger.log("The training loader: {:}".format(train_loader))
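
The second hunk adds drop_last=False to the validation loader, keeping the final, possibly smaller batch so that every validation sample is evaluated. A quick illustration of the difference (not from this repository):

import torch

data = torch.utils.data.TensorDataset(torch.arange(10))

keep_last = torch.utils.data.DataLoader(data, batch_size=4, drop_last=False)
skip_last = torch.utils.data.DataLoader(data, batch_size=4, drop_last=True)

print(len(keep_last))  # 3 batches: 4 + 4 + 2 samples
print(len(skip_last))  # 2 batches: the trailing 2 samples are discarded
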
@@ -73,6 +73,9 @@ def main(args):
    logger.log("The loss is {:}".format(loss))

    model, loss = torch.nn.DataParallel(model).cuda(), loss.cuda()
    scheduler = xmisc.LRMultiplier(
        optimizer, xmisc.get_scheduler(args.scheduler, args.lr), args.steps
    )

    import pdb
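
The third hunk builds the schedule from the updated xmisc helpers: get_scheduler(args.scheduler, args.lr) presumably returns a per-step learning-rate multiplier, and LRMultiplier applies it to the optimizer over args.steps. Both helpers are specific to this codebase; a rough stand-in using stock PyTorch, assuming a cosine multiplier over the total number of steps, could be:

import math

import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

total_steps = 1000  # stand-in for args.steps


def cosine_multiplier(step: int) -> float:
    # Factor applied to the base lr; decays from 1.0 to 0.0 over total_steps.
    return 0.5 * (1.0 + math.cos(math.pi * min(step, total_steps) / total_steps))


# LambdaLR scales the base lr by the returned factor at every scheduler.step().
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_multiplier)

for step in range(total_steps):
    optimizer.step()
    scheduler.step()

The actual multiplier depends on args.scheduler; the cosine form above is only an illustrative choice.
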
@@ -241,10 +244,11 @@ if __name__ == "__main__":
        "--valid_data_config", type=str, help="The validation dataset config path."
    )
    parser.add_argument("--data_path", type=str, help="The path to the dataset.")
    parser.add_argument("--algorithm", type=str, help="The algorithm.")
    # Optimization options
    parser.add_argument("--lr", type=float, help="The learning rate")
    parser.add_argument("--weight_decay", type=float, help="The weight decay")
    parser.add_argument("--scheduler", type=str, help="The scheduler indicator.")
    parser.add_argument("--steps", type=int, help="The total number of steps.")
    parser.add_argument("--batch_size", type=int, default=2, help="The batch size.")
    parser.add_argument("--workers", type=int, default=4, help="The number of workers")
    # Random Seed