Finalize example vis codes
This commit is contained in:
@@ -41,6 +41,11 @@ class SyntheticDEnv(data.Dataset):
|
||||
self._mean_functors = mean_functors
|
||||
self._cov_functors = cov_functors
|
||||
|
||||
self._oracle_map = None
|
||||
|
||||
def set_oracle_map(self, functor):
|
||||
self._oracle_map = functor
|
||||
|
||||
def __iter__(self):
|
||||
self._iter_num = 0
|
||||
return self
|
||||
@@ -63,7 +68,11 @@ class SyntheticDEnv(data.Dataset):
|
||||
dataset = np.random.multivariate_normal(
|
||||
mean_list, cov_matrix, size=self._num_per_task
|
||||
)
|
||||
return timestamp, torch.Tensor(dataset)
|
||||
if self._oracle_map is None:
|
||||
return timestamp, torch.Tensor(dataset)
|
||||
else:
|
||||
targets = self._oracle_map.noise_call(dataset, timestamp)
|
||||
return timestamp, (torch.Tensor(dataset), torch.Tensor(targets))
|
||||
|
||||
def __len__(self):
|
||||
return len(self._timestamp_generator)
|
||||
|
Reference in New Issue
Block a user