Update the sync data v1

This commit is contained in:
D-X-Y
2021-05-24 13:06:10 +08:00
parent da2575cc6c
commit 3ee0d348af
17 changed files with 228 additions and 274 deletions

View File

@@ -5,15 +5,13 @@ import math
import abc
import copy
import numpy as np
from typing import Optional
import torch
import torch.utils.data as data
class FitFunc(abc.ABC):
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, freedom: int, list_of_points=None, params=None):
def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"):
self._params = dict()
for i in range(freedom):
self._params[i] = None
@@ -24,6 +22,7 @@ class FitFunc(abc.ABC):
self.fit(list_of_points=list_of_points)
if params is not None:
self.set(params)
self._xstr = str(xstr)
def set(self, params):
self._params = copy.deepcopy(params)
@@ -33,6 +32,13 @@ class FitFunc(abc.ABC):
if value is None:
raise ValueError("The {:} is None".format(key))
@property
def xstr(self):
return self._xstr
def reset_xstr(self, xstr):
self._xstr = str(xstr)
@abc.abstractmethod
def __call__(self, x):
raise NotImplementedError
@@ -106,8 +112,8 @@ class FitFunc(abc.ABC):
class LinearFunc(FitFunc):
"""The linear function that outputs f(x) = a * x + b."""
def __init__(self, list_of_points=None, params=None):
super(LinearFunc, self).__init__(2, list_of_points, params)
def __init__(self, list_of_points=None, params=None, xstr="x"):
super(LinearFunc, self).__init__(2, list_of_points, params, xstr)
def __call__(self, x):
self.check_valid()
@@ -117,18 +123,19 @@ class LinearFunc(FitFunc):
return weights[0] * x + weights[1]
def __repr__(self):
return "{name}({a} * x + {b})".format(
return "{name}({a} * {x} + {b})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
x=self.xstr,
)
class QuadraticFunc(FitFunc):
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
def __init__(self, list_of_points=None, params=None):
super(QuadraticFunc, self).__init__(3, list_of_points, params)
def __init__(self, list_of_points=None, params=None, xstr="x"):
super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr)
def __call__(self, x):
self.check_valid()
@@ -138,11 +145,12 @@ class QuadraticFunc(FitFunc):
return weights[0] * x * x + weights[1] * x + weights[2]
def __repr__(self):
return "{name}({a} * x^2 + {b} * x + {c})".format(
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
x=self.xstr,
)
@@ -165,12 +173,13 @@ class CubicFunc(FitFunc):
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
def __repr__(self):
return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format(
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
name=self.__class__.__name__,
a=self._params[0],
b=self._params[1],
c=self._params[2],
d=self._params[3],
x=self.xstr,
)