Update LFNA version 1.0

This commit is contained in:
D-X-Y
2021-05-07 14:27:15 +08:00
parent 80aaac4dfa
commit 34560ad8d1
5 changed files with 120 additions and 40 deletions

View File

@@ -23,7 +23,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from models.xcore import get_model
from datasets.synthetic_core import get_synthetic_env
from datasets.synthetic_example import create_example_v1
from utils.temp_sync import optimize_fn, evaluate_fn
@@ -300,8 +300,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
alg_name2dir = OrderedDict()
alg_name2dir["Optimal"] = "use-same-timestamp"
alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
colors = ["r", "g"]
# alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
alg_name2all_containers = OrderedDict()
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
ckp_path = Path(alg_dir) / xdir / "final-ckp.pth"
xdata = torch.load(ckp_path)
alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
# load the basic model
model = get_model(
dict(model_type="simple_mlp"),
act_cls="leaky_relu",
norm_cls="identity",
input_dim=1,
output_dim=1,
)
alg2xs, alg2ys = defaultdict(list), defaultdict(list)
colors = ["r", "g"]
@@ -323,6 +335,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
"""
ckp_path = (
Path(alg_dir)
/ xdir
@@ -330,8 +343,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
)
assert ckp_path.exists()
ckp_data = torch.load(ckp_path)
"""
with torch.no_grad():
predicts = ckp_data["model"](ori_allx)
# predicts = ckp_data["model"](ori_allx)
predicts = model.forward_with_container(
ori_allx, alg_name2all_containers[alg][idx]
)
predicts = predicts.cpu()
# keep data
metric = MSEMetric()