Finalize example vis codes

This commit is contained in:
D-X-Y
2021-04-27 20:09:37 +08:00
parent 77cab08d60
commit 5eb18e8adb
8 changed files with 98 additions and 61 deletions

View File

@@ -41,6 +41,11 @@ class SyntheticDEnv(data.Dataset):
self._mean_functors = mean_functors
self._cov_functors = cov_functors
self._oracle_map = None
def set_oracle_map(self, functor):
self._oracle_map = functor
def __iter__(self):
self._iter_num = 0
return self
@@ -63,7 +68,11 @@ class SyntheticDEnv(data.Dataset):
dataset = np.random.multivariate_normal(
mean_list, cov_matrix, size=self._num_per_task
)
return timestamp, torch.Tensor(dataset)
if self._oracle_map is None:
return timestamp, torch.Tensor(dataset)
else:
targets = self._oracle_map.noise_call(dataset, timestamp)
return timestamp, (torch.Tensor(dataset), torch.Tensor(targets))
def __len__(self):
return len(self._timestamp_generator)