Update MLAML

This commit is contained in:
D-X-Y
2021-05-27 17:41:32 +00:00
parent c6db1ef65a
commit 9af34ea94d
3 changed files with 21 additions and 27 deletions

View File

@@ -5,26 +5,16 @@
#####################################################
import unittest
from xautodl.datasets.math_core import ConstantFunc, ComposedSinSFunc
from xautodl.datasets.synthetic_core import SyntheticDEnv
from xautodl.datasets.synthetic_core import get_synthetic_env
class TestSynethicEnv(unittest.TestCase):
"""Test the synethtic environment."""
def test_simple(self):
mean_generator = ConstantFunc(constant=0.1)
std_generator = ConstantFunc(constant=0.5)
dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000)
print(dataset)
for timestamp, tau in dataset:
self.assertEqual(tau.shape, (5000, 1))
def test_length(self):
mean_generator = ComposedSinSFunc({0: 1, 1: 1, 2: 3})
std_generator = ConstantFunc(constant=0.5)
dataset = SyntheticDEnv([mean_generator], [[std_generator]], num_per_task=5000)
self.assertEqual(len(dataset), 100)
dataset = SyntheticDEnv([mean_generator], [[std_generator]], mode="train")
self.assertEqual(len(dataset), 60)
versions = ["v1", "v2", "v3", "v4"]
for version in versions:
env = get_synthetic_env(version=version)
print(env)
for timestamp, tau in env:
self.assertEqual(tau.shape, (1000, env.ndim))