Re-organize GeMOSA

This commit is contained in:
D-X-Y
2021-05-27 15:44:01 +08:00
parent 8961215416
commit 6da60664f5
10 changed files with 354 additions and 350 deletions

View File

@@ -1,92 +0,0 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import abc
import copy
import numpy as np
from typing import Optional
import torch
import torch.utils.data as data
from .math_base_funcs import FitFunc
from .math_base_funcs import QuadraticFunc
from .math_base_funcs import QuarticFunc
class ConstantFunc(FitFunc):
"""The constant function: f(x) = c."""
def __init__(self, constant=None, xstr="x"):
param = dict()
param[0] = constant
super(ConstantFunc, self).__init__(0, None, param, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0]
def fit(self, **kwargs):
raise NotImplementedError
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
class ComposedSinFunc(FitFunc):
"""The composed sin function that outputs:
f(x) = a * sin( b*x ) + c
"""
def __init__(self, params, xstr="x"):
super(ComposedSinFunc, self).__init__(3, None, params, xstr)
def __call__(self, x):
self.check_valid()
a = self._params[0]
b = self._params[1]
c = self._params[2]
return a * math.sin(b * x) + c
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a} * sin({b} * {x}) + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class ComposedCosFunc(FitFunc):
"""The composed sin function that outputs:
f(x) = a * cos( b*x ) + c
"""
def __init__(self, params, xstr="x"):
super(ComposedCosFunc, self).__init__(3, None, params, xstr)
def __call__(self, x):
self.check_valid()
a = self._params[0]
b = self._params[1]
c = self._params[2]
return a * math.cos(b * x) + c
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a} * sin({b} * {x}) + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)

View File

@@ -5,34 +5,33 @@ import math
import abc
import copy
import numpy as np
import torch
class FitFunc(abc.ABC):
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
class MathFunc(abc.ABC):
"""The math function -- a virtual class defining some APIs."""
def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"):
def __init__(self, freedom: int, params=None, xstr="x"):
# initialize as empty
self._params = dict()
for i in range(freedom):
self._params[i] = None
self._freedom = freedom
if list_of_points is not None and params is not None:
raise ValueError("list_of_points and params can not be set simultaneously")
if list_of_points is not None:
self.fit(list_of_points=list_of_points)
if params is not None:
self.set(params)
self._xstr = str(xstr)
self._skip_check = True
def set(self, params):
self._params = copy.deepcopy(params)
for key in range(self._freedom):
param = copy.deepcopy(params[key])
self._params[key] = param
def check_valid(self):
# for key, value in self._params.items():
for key in range(self._freedom):
value = self._params[key]
if value is None:
raise ValueError("The {:} is None".format(key))
if not self._skip_check:
for key in range(self._freedom):
value = self._params[key]
if value is None:
raise ValueError("The {:} is None".format(key))
@property
def xstr(self):
@@ -45,7 +44,8 @@ class FitFunc(abc.ABC):
def __call__(self, x):
raise NotImplementedError
def noise_call(self, x, std=0.1):
@abc.abstractmethod
def noise_call(self, x, std):
clean_y = self.__call__(x)
if isinstance(clean_y, np.ndarray):
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
@@ -53,169 +53,7 @@ class FitFunc(abc.ABC):
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
return noise_y
@abc.abstractmethod
def _getitem(self, x):
raise NotImplementedError
def fit(self, **kwargs):
list_of_points = kwargs["list_of_points"]
max_iter, lr_max, verbose = (
kwargs.get("max_iter", 900),
kwargs.get("lr_max", 1.0),
kwargs.get("verbose", False),
)
with torch.no_grad():
data = torch.Tensor(list_of_points).type(torch.float32)
assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format(
data.shape
)
x, y = data[:, 0], data[:, 1]
weights = torch.nn.Parameter(torch.Tensor(self._freedom))
torch.nn.init.normal_(weights, mean=0.0, std=1.0)
optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=[
int(max_iter * 0.25),
int(max_iter * 0.5),
int(max_iter * 0.75),
],
gamma=0.1,
)
if verbose:
print("The optimizer: {:}".format(optimizer))
best_loss = None
for _iter in range(max_iter):
y_hat = self._getitem(x, weights)
loss = torch.mean(torch.abs(y - y_hat))
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
if verbose:
print(
"In the fit, loss at the {:02d}/{:02d}-th iter is {:}".format(
_iter, max_iter, loss.item()
)
)
# Update the params
if best_loss is None or best_loss > loss.item():
best_loss = loss.item()
for i in range(self._freedom):
self._params[i] = weights[i].item()
def __repr__(self):
return "{name}(freedom={freedom})".format(
name=self.__class__.__name__, freedom=freedom
)
class LinearFunc(FitFunc):
"""The linear function that outputs f(x) = a * x + b."""
def __init__(self, list_of_points=None, params=None, xstr="x"):
super(LinearFunc, self).__init__(2, list_of_points, params, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0] * x + self._params[1]
def _getitem(self, x, weights):
return weights[0] * x + weights[1]
def __repr__(self):
return "{name}({a} * {x} + {b})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
x=self.xstr,
)
class QuadraticFunc(FitFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None, params=None, xstr="x"):
super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0] * x * x + self._params[1] * x + self._params[2]
def _getitem(self, x, weights):
return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self):
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class CubicFunc(FitFunc):
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
def __init__(self, list_of_points=None):
super(CubicFunc, self).__init__(4, list_of_points)
def __call__(self, x):
self.check_valid()
return (
self._params[0] * x ** 3
+ self._params[1] * x ** 2
+ self._params[2] * x
+ self._params[3]
)
def _getitem(self, x, weights):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self):
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
x=self.xstr,
)
class QuarticFunc(FitFunc):
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
def __init__(self, list_of_points=None):
super(QuarticFunc, self).__init__(5, list_of_points)
def __call__(self, x):
self.check_valid()
return (
self._params[0] * x ** 4
+ self._params[1] * x ** 3
+ self._params[2] * x ** 2
+ self._params[3] * x
+ self._params[4]
)
def _getitem(self, x, weights):
return (
weights[0] * x ** 4
+ weights[1] * x ** 3
+ weights[2] * x ** 2
+ weights[3] * x
+ weights[4]
)
def __repr__(self):
return "{name}({a} * x^4 + {b} * x^3 + {c} * x^2 + {d} * x + {e})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
e=self._params[3],
)

View File

@@ -1,10 +1,14 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
#####################################################
from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
from .math_dynamic_funcs import DynamicLinearFunc
from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_dynamic_funcs import DynamicSinQuadraticFunc
from .math_adv_funcs import ConstantFunc
from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc
from .math_static_funcs import (
LinearSFunc,
QuadraticSFunc,
CubicSFunc,
QuarticSFunc,
ConstantFunc,
ComposedSinSFunc,
ComposedCosSFunc,
)
from .math_dynamic_funcs import LinearDFunc, QuadraticDFunc, SinQuadraticDFunc
from .math_dynamic_generator import GaussianDGenerator

View File

@@ -6,23 +6,17 @@ import abc
import copy
import numpy as np
from .math_base_funcs import FitFunc
from .math_base_funcs import MathFunc
class DynamicFunc(FitFunc):
"""The dynamic quadratic function, where each param is a function."""
class DynamicFunc(MathFunc):
"""The dynamic function, where each param is a function."""
def __init__(self, freedom: int, params=None, xstr="x"):
if params is not None:
for param in params:
param.reset_xstr("t") if isinstance(param, FitFunc) else None
super(DynamicFunc, self).__init__(freedom, None, params, xstr)
def __call__(self, x, timestamp):
raise NotImplementedError
def _getitem(self, x, weights):
raise NotImplementedError
for key, param in params.items():
param.reset_xstr("t") if isinstance(param, MathFunc) else None
super(DynamicFunc, self).__init__(freedom, params, xstr)
def noise_call(self, x, timestamp, std):
clean_y = self.__call__(x, timestamp)
@@ -33,13 +27,13 @@ class DynamicFunc(FitFunc):
return noise_y
class DynamicLinearFunc(DynamicFunc):
class LinearDFunc(DynamicFunc):
"""The dynamic linear function that outputs f(x) = a * x + b.
The a and b is a function of timestamp.
"""
def __init__(self, params=None, xstr="x"):
super(DynamicLinearFunc, self).__init__(3, params, xstr)
def __init__(self, params, xstr="x"):
super(LinearDFunc, self).__init__(2, params, xstr)
def __call__(self, x, timestamp):
a = self._params[0](timestamp)
@@ -57,18 +51,15 @@ class DynamicLinearFunc(DynamicFunc):
)
class DynamicQuadraticFunc(DynamicFunc):
class QuadraticDFunc(DynamicFunc):
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
The a, b, and c is a function of timestamp.
"""
def __init__(self, params=None):
super(DynamicQuadraticFunc, self).__init__(3, params)
def __init__(self, params, xstr="x"):
super(QuadraticDFunc, self).__init__(3, params)
def __call__(
self,
x,
):
def __call__(self, x, timestamp):
self.check_valid()
a = self._params[0](timestamp)
b = self._params[1](timestamp)
@@ -78,38 +69,37 @@ class DynamicQuadraticFunc(DynamicFunc):
return a * x * x + b * x + c
def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c})".format(
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class DynamicSinQuadraticFunc(DynamicFunc):
class SinQuadraticDFunc(DynamicFunc):
"""The dynamic quadratic function that outputs f(x) = sin(a * x^2 + b * x + c).
The a, b, and c is a function of timestamp.
"""
def __init__(self, params=None):
super(DynamicSinQuadraticFunc, self).__init__(3, params)
super(SinQuadraticDFunc, self).__init__(3, params)
def __call__(
self,
x,
):
def __call__(self, x, timestamp):
self.check_valid()
a = self._params[0](timestamp)
b = self._params[1](timestamp)
c = self._params[2](timestamp)
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
return math.sin(a * x * x + b * x + c)
return np.sin(a * x * x + b * x + c)
def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c})".format(
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)

View File

@@ -0,0 +1,225 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import math
import abc
import copy
import numpy as np
from .math_base_funcs import MathFunc
class StaticFunc(MathFunc):
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, freedom: int, params=None, xstr="x"):
super(StaticFunc, self).__init__(freedom, params, xstr)
@abc.abstractmethod
def __call__(self, x):
raise NotImplementedError
def noise_call(self, x, std):
clean_y = self.__call__(x)
if isinstance(clean_y, np.ndarray):
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
else:
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
return noise_y
def __repr__(self):
return "{name}(freedom={freedom})".format(
name=self.__class__.__name__, freedom=freedom
)
class LinearSFunc(StaticFunc):
"""The linear function that outputs f(x) = a * x + b."""
def __init__(self, params=None, xstr="x"):
super(LinearSFunc, self).__init__(2, params, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0] * x + self._params[1]
def _getitem(self, x, weights):
return weights[0] * x + weights[1]
def __repr__(self):
return "{name}({a} * {x} + {b})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
x=self.xstr,
)
class QuadraticSFunc(StaticFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, params=None, xstr="x"):
super(QuadraticSFunc, self).__init__(3, params, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0] * x * x + self._params[1] * x + self._params[2]
def _getitem(self, x, weights):
return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self):
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class CubicSFunc(StaticFunc):
"""The cubic function that outputs f(x) = a * x^3 + b * x^2 + c * x + d."""
def __init__(self, params=None, xstr="x"):
super(CubicSFunc, self).__init__(4, params, xstr)
def __call__(self, x):
self.check_valid()
return (
self._params[0] * x ** 3
+ self._params[1] * x ** 2
+ self._params[2] * x
+ self._params[3]
)
def _getitem(self, x, weights):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self):
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
x=self.xstr,
)
class QuarticSFunc(StaticFunc):
"""The quartic function that outputs f(x) = a * x^4 + b * x^3 + c * x^2 + d * x + e."""
def __init__(self, params=None, xstr="x"):
super(QuarticSFunc, self).__init__(5, params, xstr)
def __call__(self, x):
self.check_valid()
return (
self._params[0] * x ** 4
+ self._params[1] * x ** 3
+ self._params[2] * x ** 2
+ self._params[3] * x
+ self._params[4]
)
def _getitem(self, x, weights):
return (
weights[0] * x ** 4
+ weights[1] * x ** 3
+ weights[2] * x ** 2
+ weights[3] * x
+ weights[4]
)
def __repr__(self):
return (
"{name}({a} * {x}^4 + {b} * {x}^3 + {c} * {x}^2 + {d} * {x} + {e})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
e=self._params[3],
x=self.xstr,
)
)
### advanced functions
class ConstantFunc(StaticFunc):
"""The constant function: f(x) = c."""
def __init__(self, constant, xstr="x"):
super(ConstantFunc, self).__init__(1, {0: constant}, xstr)
def __call__(self, x):
self.check_valid()
return self._params[0]
def fit(self, **kwargs):
raise NotImplementedError
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0])
class ComposedSinSFunc(StaticFunc):
"""The composed sin function that outputs:
f(x) = a * sin( b*x ) + c
"""
def __init__(self, params, xstr="x"):
super(ComposedSinSFunc, self).__init__(3, params, xstr)
def __call__(self, x):
self.check_valid()
a = self._params[0]
b = self._params[1]
c = self._params[2]
return a * math.sin(b * x) + c
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a} * sin({b} * {x}) + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
class ComposedCosSFunc(StaticFunc):
"""The composed sin function that outputs:
f(x) = a * cos( b*x ) + c
"""
def __init__(self, params, xstr="x"):
super(ComposedCosSFunc, self).__init__(3, params, xstr)
def __call__(self, x):
self.check_valid()
a = self._params[0]
b = self._params[1]
c = self._params[2]
return a * math.cos(b * x) + c
def _getitem(self, x, weights):
raise NotImplementedError
def __repr__(self):
return "{name}({a} * sin({b} * {x}) + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)

View File

@@ -1,13 +1,13 @@
import math
from .synthetic_utils import TimeStamp
from .synthetic_env import SyntheticDEnv
from .math_core import LinearFunc
from .math_core import DynamicLinearFunc
from .math_core import DynamicQuadraticFunc, DynamicSinQuadraticFunc
from .math_core import LinearSFunc
from .math_core import LinearDFunc
from .math_core import QuadraticDFunc, SinQuadraticDFunc
from .math_core import (
ConstantFunc,
ComposedSinFunc as SinFunc,
ComposedCosFunc as CosFunc,
ComposedSinSFunc as SinFunc,
ComposedCosSFunc as CosFunc,
)
from .math_core import GaussianDGenerator
@@ -17,7 +17,7 @@ __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]
def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, version="v1"):
max_time = math.pi * 10
if version == "v1":
if version.lower() == "v1":
mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1)
data_generator = GaussianDGenerator(
@@ -26,7 +26,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
time_generator = TimeStamp(
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
)
oracle_map = DynamicLinearFunc(
oracle_map = LinearDFunc(
params={
0: SinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}), # 2 sin(t) + 2.2
1: SinFunc(params={0: 1.5, 1: 0.6, 2: 1.8}), # 1.5 sin(0.6t) + 1.8
@@ -35,7 +35,8 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
)
elif version == "v2":
dynamic_env.set_regression()
elif version.lower() == "v2":
mean_generator = ConstantFunc(0)
std_generator = ConstantFunc(1)
data_generator = GaussianDGenerator(
@@ -44,16 +45,17 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
time_generator = TimeStamp(
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
)
oracle_map = DynamicQuadraticFunc(
oracle_map = QuadraticDFunc(
params={
0: LinearFunc(params={0: 0.1, 1: 0}), # 0.1 * t
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
2: ConstantFunc(0),
0: LinearSFunc(params={0: 0.1, 1: 0}), # 0.1 * t
1: ConstantFunc(0),
2: CosFunc(params={0: 4.0, 1: 10, 2: 0}), # 4 * cos(10 * t)
}
)
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
)
dynamic_env.set_regression()
elif version.lower() == "v3":
mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0}) # sin(t)
std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1}) # 0.5 cos(t) + 1
@@ -63,7 +65,7 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
time_generator = TimeStamp(
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
)
oracle_map = DynamicSinQuadraticFunc(
oracle_map = SinQuadraticDFunc(
params={
0: CosFunc(params={0: 0.5, 1: 1, 2: 1}), # 0.5 cos(t) + 1
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
@@ -73,6 +75,9 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
)
dynamic_env.set_regression()
elif version.lower() == "v4":
dynamic_env.set_classification(2)
else:
raise ValueError("Unknown version: {:}".format(version))
return dynamic_env

View File

@@ -49,6 +49,10 @@ class SyntheticDEnv(data.Dataset):
self._meta_info["task"] = "classification"
self._meta_info["num_classes"] = int(num_classes)
@property
def oracle_map(self):
return self._oracle_map
@property
def meta_info(self):
return self._meta_info