Fix bugs
This commit is contained in:
@@ -66,23 +66,24 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
|
||||
|
||||
def find_min(cur, others):
|
||||
if cur is None:
|
||||
return float(others.min())
|
||||
return float(others)
|
||||
else:
|
||||
return float(min(cur, others.min()))
|
||||
return float(min(cur, others))
|
||||
|
||||
|
||||
def find_max(cur, others):
|
||||
if cur is None:
|
||||
return float(others.max())
|
||||
else:
|
||||
return float(max(cur, others.max()))
|
||||
return float(max(cur, others))
|
||||
|
||||
|
||||
def compare_cl(save_dir):
|
||||
save_dir = Path(str(save_dir))
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
dynamic_env, function = create_example_v1(
|
||||
timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
|
||||
# timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
|
||||
timestamp_config=None,
|
||||
num_per_task=1000,
|
||||
)
|
||||
|
||||
@@ -104,13 +105,11 @@ def compare_cl(save_dir):
|
||||
current_data["lfna_xaxis_all"] = xaxis_all
|
||||
current_data["lfna_yaxis_all"] = yaxis_all
|
||||
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
# compute cl-min
|
||||
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all)
|
||||
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1
|
||||
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
|
||||
cl_xaxis_max = (
|
||||
find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) + idx * 0.1
|
||||
)
|
||||
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
|
||||
|
||||
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
|
||||
@@ -142,8 +141,8 @@ def compare_cl(save_dir):
|
||||
"yaxis": cl_yaxis_all,
|
||||
"color": "r",
|
||||
"s": 10,
|
||||
"xlim": (-6, 6 + timestamp * 0.2),
|
||||
"ylim": (-40, 40),
|
||||
"xlim": (round(cl_xaxis_all.min(), 1), round(cl_xaxis_all.max(), 1)),
|
||||
"ylim": (round(cl_xaxis_all.min(), 1), round(cl_yaxis_all.max(), 1)),
|
||||
"alpha": 0.99,
|
||||
"label": "Continual Learning",
|
||||
}
|
||||
@@ -151,10 +150,10 @@ def compare_cl(save_dir):
|
||||
|
||||
draw_multi_fig(
|
||||
save_dir,
|
||||
timestamp,
|
||||
idx,
|
||||
scatter_list,
|
||||
wh=(2000, 1300),
|
||||
fig_title="Timestamp={:03d}".format(timestamp),
|
||||
fig_title="Timestamp={:03d}".format(idx),
|
||||
)
|
||||
print("Save all figures into {:}".format(save_dir))
|
||||
save_dir = save_dir.resolve()
|
||||
|
Reference in New Issue
Block a user