add autodl
This commit is contained in:
32
AutoDL-Projects/xautodl/xmisc/sampler_utils.py
Normal file
32
AutoDL-Projects/xautodl/xmisc/sampler_utils.py
Normal file
@@ -0,0 +1,32 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
|
||||
#####################################################
|
||||
import random
|
||||
|
||||
|
||||
class BatchSampler:
|
||||
"""A batch sampler used for single machine training."""
|
||||
|
||||
def __init__(self, dataset, batch, steps):
|
||||
self._num_per_epoch = len(dataset)
|
||||
self._iter_per_epoch = self._num_per_epoch // batch
|
||||
self._steps = steps
|
||||
self._batch = batch
|
||||
if self._num_per_epoch < self._batch:
|
||||
raise ValueError(
|
||||
"The dataset size must be larger than batch={:}".format(batch)
|
||||
)
|
||||
self._indexes = list(range(self._num_per_epoch))
|
||||
|
||||
def __iter__(self):
|
||||
"""
|
||||
yield a batch of indexes using random sampling
|
||||
"""
|
||||
for i in range(self._steps):
|
||||
if i % self._iter_per_epoch == 0:
|
||||
random.shuffle(self._indexes)
|
||||
j = i % self._iter_per_epoch
|
||||
yield self._indexes[j * self._batch : (j + 1) * self._batch]
|
||||
|
||||
def __len__(self):
|
||||
return self._steps
|
Reference in New Issue
Block a user