Update GeMOSA v4
This commit is contained in:
@@ -119,10 +119,15 @@ class SyntheticDEnv(data.Dataset):
|
||||
def __call__(self, timestamp):
|
||||
dataset = self._data_generator(timestamp, self._num_per_task)
|
||||
targets = self._oracle_map.noise_call(dataset, timestamp, self._noise)
|
||||
return torch.Tensor([timestamp]), (
|
||||
torch.Tensor(dataset),
|
||||
torch.Tensor(targets),
|
||||
)
|
||||
if isinstance(dataset, np.ndarray):
|
||||
dataset = torch.from_numpy(dataset)
|
||||
else:
|
||||
dataset = torch.Tensor(dataset)
|
||||
if isinstance(targets, np.ndarray):
|
||||
targets = torch.from_numpy(targets)
|
||||
else:
|
||||
targets = torch.Tensor(targets)
|
||||
return torch.Tensor([timestamp]), (dataset, targets)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._time_generator)
|
||||
|
Reference in New Issue
Block a user