Update codes

This commit is contained in:
D-X-Y
2021-04-26 06:16:08 -07:00
parent e1818694a4
commit 8358d71cdf
4 changed files with 76 additions and 31 deletions

View File

@@ -24,10 +24,7 @@ if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from datasets import ConstantGenerator, SinGenerator, SyntheticDEnv
from datasets import DynamicQuadraticFunc
from datasets.synthetic_example import create_example_v1
from utils.temp_sync import optimize_fn, evaluate_fn
@@ -61,43 +58,72 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
tick.label.set_rotation(10)
for tick in cur_ax.yaxis.get_major_ticks():
tick.label.set_fontsize(LabelSize - font_gap)
plt.legend(loc=1, fontsize=LegendFontsize)
cur_ax.legend(loc=1, fontsize=LegendFontsize)
fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
plt.close("all")
def find_min(cur, others):
if cur is None:
return float(others.min())
else:
return float(min(cur, others.min()))
def find_max(cur, others):
if cur is None:
return float(others.max())
else:
return float(max(cur, others.max()))
def compare_cl(save_dir):
save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env, function = create_example_v1(100, num_per_task=1000)
dynamic_env, function = create_example_v1(
timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
num_per_task=1000,
)
additional_xaxis = np.arange(-6, 6, 0.2)
models = dict()
cl_function = copy.deepcopy(function)
cl_function.set_timestamp(0)
cl_xaxis_all = None
cl_xaxis_min = None
cl_xaxis_max = None
all_data = OrderedDict()
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
xaxis_all = dataset[:, 0].numpy()
# xaxis_all = np.concatenate((additional_xaxis, xaxis_all))
# compute the ground truth
current_data = dict()
function.set_timestamp(timestamp)
yaxis_all = function.noise_call(xaxis_all)
current_data["lfna_xaxis_all"] = xaxis_all
current_data["lfna_yaxis_all"] = yaxis_all
# create CL data
if cl_xaxis_all is None:
cl_xaxis_all = xaxis_all
else:
cl_xaxis_all = np.concatenate((cl_xaxis_all, xaxis_all + timestamp * 0.2))
cl_yaxis_all = cl_function(cl_xaxis_all)
import pdb
pdb.set_trace()
# compute cl-min
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all)
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
current_data["cl_xaxis_all"] = cl_xaxis_all
current_data["cl_yaxis_all"] = cl_yaxis_all
all_data[timestamp] = current_data
for idx, (timestamp, xdata) in enumerate(tqdm(all_data.items(), ncols=50)):
scatter_list = []
scatter_list.append(
{
"xaxis": xaxis_all,
"yaxis": yaxis_all,
"xaxis": xdata["lfna_xaxis_all"],
"yaxis": xdata["lfna_yaxis_all"],
"color": "k",
"s": 10,
"alpha": 0.99,
@@ -107,6 +133,9 @@ def compare_cl(save_dir):
}
)
cl_xaxis_all = current_data["cl_xaxis_all"]
cl_yaxis_all = current_data["cl_yaxis_all"]
scatter_list.append(
{
"xaxis": cl_xaxis_all,
@@ -121,15 +150,21 @@ def compare_cl(save_dir):
)
draw_multi_fig(
save_dir, timestamp, scatter_list,
wh=(2000, 1300), fig_title="Timestamp={:03d}".format(timestamp)
save_dir,
timestamp,
scatter_list,
wh=(2000, 1300),
fig_title="Timestamp={:03d}".format(timestamp),
)
print("Save all figures into {:}".format(save_dir))
save_dir = save_dir.resolve()
cmd = "ffmpeg -y -i {xdir}/%04d.png -pix_fmt yuv420p -vf fps=2 -vf scale=2000:1300 -vb 5000k {xdir}/vis.mp4".format(
xdir=save_dir
base_cmd = (
"ffmpeg -y -i {xdir}/%04d.png -vf fps=2 -vf scale=2000:1300 -vb 5000k".format(
xdir=save_dir
)
)
os.system(cmd)
os.system("{:} -pix_fmt yuv420p {xdir}/vis.mp4".format(base_cmd, xdir=save_dir))
os.system("{:} -c:a libvorbis {xdir}/vis.webm".format(base_cmd, xdir=save_dir))
if __name__ == "__main__":