Fix bugs
This commit is contained in:
@@ -63,7 +63,7 @@ class SyntheticDEnv(data.Dataset):
|
||||
dataset = np.random.multivariate_normal(
|
||||
mean_list, cov_matrix, size=self._num_per_task
|
||||
)
|
||||
return index, torch.Tensor(dataset)
|
||||
return timestamp, torch.Tensor(dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._timestamp_generator)
|
||||
|
@@ -8,9 +8,10 @@ from .synthetic_env import SyntheticDEnv
|
||||
|
||||
|
||||
def create_example_v1(
|
||||
timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||
timestamp_config=None,
|
||||
num_per_task=5000,
|
||||
):
|
||||
# timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
||||
|
||||
|
Reference in New Issue
Block a user