Correct the codes
This commit is contained in:
@@ -22,12 +22,12 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
|
||||
[mean_generator], [[std_generator]], (-2, 2)
|
||||
)
|
||||
time_generator = TimeStamp(
|
||||
min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode
|
||||
min_timestamp=0, max_timestamp=math.pi * 8, num=total_timestamp, mode=mode
|
||||
)
|
||||
oracle_map = DynamicLinearFunc(
|
||||
params={
|
||||
0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
|
||||
1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}),
|
||||
1: ComposedSinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}),
|
||||
}
|
||||
)
|
||||
dynamic_env = SyntheticDEnv(
|
||||
|
@@ -28,7 +28,7 @@ class UnifiedSplit:
|
||||
self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
|
||||
elif mode.lower() in ("test", "testing"):
|
||||
self._indexes = all_indexes[num_of_train + num_of_valid :]
|
||||
elif mode.lower() in ("trainval", "trainvalidation"):
|
||||
elif mode.lower() in ("trainval", "trainvalid", "trainvalidation"):
|
||||
self._indexes = all_indexes[: num_of_train + num_of_valid]
|
||||
else:
|
||||
raise ValueError("Unkonwn mode of {:}".format(mode))
|
||||
|
Reference in New Issue
Block a user