Update LFNA version 1.0
This commit is contained in:
@@ -23,7 +23,7 @@ lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
|
||||
|
||||
from models.xcore import get_model
|
||||
from datasets.synthetic_core import get_synthetic_env
|
||||
from datasets.synthetic_example import create_example_v1
|
||||
from utils.temp_sync import optimize_fn, evaluate_fn
|
||||
@@ -300,8 +300,20 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
||||
|
||||
alg_name2dir = OrderedDict()
|
||||
alg_name2dir["Optimal"] = "use-same-timestamp"
|
||||
alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
|
||||
colors = ["r", "g"]
|
||||
# alg_name2dir["Supervised Learning (History Data)"] = "use-all-past-data"
|
||||
alg_name2all_containers = OrderedDict()
|
||||
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
||||
ckp_path = Path(alg_dir) / xdir / "final-ckp.pth"
|
||||
xdata = torch.load(ckp_path)
|
||||
alg_name2all_containers[alg] = xdata["w_container_per_epoch"]
|
||||
# load the basic model
|
||||
model = get_model(
|
||||
dict(model_type="simple_mlp"),
|
||||
act_cls="leaky_relu",
|
||||
norm_cls="identity",
|
||||
input_dim=1,
|
||||
output_dim=1,
|
||||
)
|
||||
|
||||
alg2xs, alg2ys = defaultdict(list), defaultdict(list)
|
||||
colors = ["r", "g"]
|
||||
@@ -323,6 +335,7 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
||||
plot_scatter(cur_ax, allx, ally, "k", 0.99, linewidths, "Raw Data")
|
||||
|
||||
for idx_alg, (alg, xdir) in enumerate(alg_name2dir.items()):
|
||||
"""
|
||||
ckp_path = (
|
||||
Path(alg_dir)
|
||||
/ xdir
|
||||
@@ -330,8 +343,12 @@ def compare_algs_v2(save_dir, alg_dir="./outputs/lfna-synthetic"):
|
||||
)
|
||||
assert ckp_path.exists()
|
||||
ckp_data = torch.load(ckp_path)
|
||||
"""
|
||||
with torch.no_grad():
|
||||
predicts = ckp_data["model"](ori_allx)
|
||||
# predicts = ckp_data["model"](ori_allx)
|
||||
predicts = model.forward_with_container(
|
||||
ori_allx, alg_name2all_containers[alg][idx]
|
||||
)
|
||||
predicts = predicts.cpu()
|
||||
# keep data
|
||||
metric = MSEMetric()
|
||||
|
Reference in New Issue
Block a user