This commit is contained in:
D-X-Y
2021-04-26 21:44:03 +08:00
parent 8358d71cdf
commit d3371296a7
10 changed files with 270 additions and 264 deletions

View File

@@ -66,23 +66,24 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
def find_min(cur, others):
if cur is None:
return float(others.min())
return float(others)
else:
return float(min(cur, others.min()))
return float(min(cur, others))
def find_max(cur, others):
if cur is None:
return float(others.max())
else:
return float(max(cur, others.max()))
return float(max(cur, others))
def compare_cl(save_dir):
save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env, function = create_example_v1(
timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
# timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
timestamp_config=None,
num_per_task=1000,
)
@@ -104,13 +105,11 @@ def compare_cl(save_dir):
current_data["lfna_xaxis_all"] = xaxis_all
current_data["lfna_yaxis_all"] = yaxis_all
import pdb
pdb.set_trace()
# compute cl-min
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all)
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
cl_xaxis_max = (
find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) + idx * 0.1
)
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
@@ -142,8 +141,8 @@ def compare_cl(save_dir):
"yaxis": cl_yaxis_all,
"color": "r",
"s": 10,
"xlim": (-6, 6 + timestamp * 0.2),
"ylim": (-40, 40),
"xlim": (round(cl_xaxis_all.min(), 1), round(cl_xaxis_all.max(), 1)),
"ylim": (round(cl_xaxis_all.min(), 1), round(cl_yaxis_all.max(), 1)),
"alpha": 0.99,
"label": "Continual Learning",
}
@@ -151,10 +150,10 @@ def compare_cl(save_dir):
draw_multi_fig(
save_dir,
timestamp,
idx,
scatter_list,
wh=(2000, 1300),
fig_title="Timestamp={:03d}".format(timestamp),
fig_title="Timestamp={:03d}".format(idx),
)
print("Save all figures into {:}".format(save_dir))
save_dir = save_dir.resolve()