This commit is contained in:
D-X-Y
2021-04-26 21:44:03 +08:00
parent 8358d71cdf
commit d3371296a7
10 changed files with 270 additions and 264 deletions

View File

@@ -63,7 +63,7 @@ class SyntheticDEnv(data.Dataset):
dataset = np.random.multivariate_normal(
mean_list, cov_matrix, size=self._num_per_task
)
return index, torch.Tensor(dataset)
return timestamp, torch.Tensor(dataset)
def __len__(self):
return len(self._timestamp_generator)

View File

@@ -8,9 +8,10 @@ from .synthetic_env import SyntheticDEnv
def create_example_v1(
timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
timestamp_config=None,
num_per_task=5000,
):
# timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
mean_generator = ComposedSinFunc()
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)