Update GeMOSA v4

This commit is contained in:
D-X-Y
2021-05-27 17:30:44 +08:00
parent 1ce0b80776
commit b6e11c6360
8 changed files with 147 additions and 39 deletions

View File

@@ -119,10 +119,15 @@ class SyntheticDEnv(data.Dataset):
def __call__(self, timestamp):
dataset = self._data_generator(timestamp, self._num_per_task)
targets = self._oracle_map.noise_call(dataset, timestamp, self._noise)
return torch.Tensor([timestamp]), (
torch.Tensor(dataset),
torch.Tensor(targets),
)
if isinstance(dataset, np.ndarray):
dataset = torch.from_numpy(dataset)
else:
dataset = torch.Tensor(dataset)
if isinstance(targets, np.ndarray):
targets = torch.from_numpy(targets)
else:
targets = torch.Tensor(targets)
return torch.Tensor([timestamp]), (dataset, targets)
def __len__(self):
return len(self._time_generator)