Update the sync data v1
This commit is contained in:
@@ -5,15 +5,13 @@ import math
|
||||
import abc
|
||||
import copy
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
|
||||
class FitFunc(abc.ABC):
|
||||
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
|
||||
def __init__(self, freedom: int, list_of_points=None, params=None):
|
||||
def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"):
|
||||
self._params = dict()
|
||||
for i in range(freedom):
|
||||
self._params[i] = None
|
||||
@@ -24,6 +22,7 @@ class FitFunc(abc.ABC):
|
||||
self.fit(list_of_points=list_of_points)
|
||||
if params is not None:
|
||||
self.set(params)
|
||||
self._xstr = str(xstr)
|
||||
|
||||
def set(self, params):
|
||||
self._params = copy.deepcopy(params)
|
||||
@@ -33,6 +32,13 @@ class FitFunc(abc.ABC):
|
||||
if value is None:
|
||||
raise ValueError("The {:} is None".format(key))
|
||||
|
||||
@property
|
||||
def xstr(self):
|
||||
return self._xstr
|
||||
|
||||
def reset_xstr(self, xstr):
|
||||
self._xstr = str(xstr)
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, x):
|
||||
raise NotImplementedError
|
||||
@@ -106,8 +112,8 @@ class FitFunc(abc.ABC):
|
||||
class LinearFunc(FitFunc):
|
||||
"""The linear function that outputs f(x) = a * x + b."""
|
||||
|
||||
def __init__(self, list_of_points=None, params=None):
|
||||
super(LinearFunc, self).__init__(2, list_of_points, params)
|
||||
def __init__(self, list_of_points=None, params=None, xstr="x"):
|
||||
super(LinearFunc, self).__init__(2, list_of_points, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
@@ -117,18 +123,19 @@ class LinearFunc(FitFunc):
|
||||
return weights[0] * x + weights[1]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x + {b})".format(
|
||||
return "{name}({a} * {x} + {b})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
class QuadraticFunc(FitFunc):
|
||||
"""The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
|
||||
def __init__(self, list_of_points=None, params=None):
|
||||
super(QuadraticFunc, self).__init__(3, list_of_points, params)
|
||||
def __init__(self, list_of_points=None, params=None, xstr="x"):
|
||||
super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr)
|
||||
|
||||
def __call__(self, x):
|
||||
self.check_valid()
|
||||
@@ -138,11 +145,12 @@ class QuadraticFunc(FitFunc):
|
||||
return weights[0] * x * x + weights[1] * x + weights[2]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^2 + {b} * x + {c})".format(
|
||||
return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
@@ -165,12 +173,13 @@ class CubicFunc(FitFunc):
|
||||
return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format(
|
||||
return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
d=self._params[3],
|
||||
x=self.xstr,
|
||||
)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user