Update the sync data v1
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
|
||||
#####################################################
|
||||
import math
|
||||
from .synthetic_utils import TimeStamp
|
||||
from .synthetic_env import EnvSampler
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
from .math_core import LinearFunc
|
||||
from .math_core import DynamicLinearFunc
|
||||
from .math_core import DynamicQuadraticFunc
|
||||
from .math_core import ConstantFunc, ComposedSinFunc
|
||||
from .math_core import GaussianDGenerator
|
||||
|
||||
|
||||
__all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
|
||||
@@ -17,42 +18,21 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
|
||||
if version == "v1":
|
||||
mean_generator = ConstantFunc(0)
|
||||
std_generator = ConstantFunc(1)
|
||||
elif version == "v2":
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5)
|
||||
else:
|
||||
raise ValueError("Unknown version: {:}".format(version))
|
||||
dynamic_env = SyntheticDEnv(
|
||||
[mean_generator],
|
||||
[[std_generator]],
|
||||
num_per_task=num_per_task,
|
||||
timestamp_config=dict(
|
||||
min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode
|
||||
),
|
||||
)
|
||||
if version == "v1":
|
||||
function = DynamicLinearFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
amplitude_scale=ConstantFunc(3.0),
|
||||
num_sin_phase=9,
|
||||
sin_speed_use_power=False,
|
||||
data_generator = GaussianDGenerator(
|
||||
[mean_generator], [[std_generator]], (-2, 2)
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
elif version == "v2":
|
||||
function = DynamicQuadraticFunc()
|
||||
function_param = dict()
|
||||
function_param[0] = ComposedSinFunc(
|
||||
num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
|
||||
time_generator = TimeStamp(
|
||||
min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode
|
||||
)
|
||||
function_param[1] = ConstantFunc(constant=0.9)
|
||||
function_param[2] = ComposedSinFunc(
|
||||
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
|
||||
oracle_map = DynamicLinearFunc(
|
||||
params={
|
||||
0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
|
||||
1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}),
|
||||
}
|
||||
)
|
||||
dynamic_env = SyntheticDEnv(
|
||||
data_generator, oracle_map, time_generator, num_per_task
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown version: {:}".format(version))
|
||||
|
||||
function.set(function_param)
|
||||
# dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||
dynamic_env.set_oracle_map(function)
|
||||
return dynamic_env
|
||||
|
Reference in New Issue
Block a user