add sample phase and try to get log prob

mhz
2024-09-08 23:26:49 +02:00
parent 0c4b597dd2
commit 5dccf590e7
2 changed files with 76 additions and 26 deletions


@@ -1,5 +1,5 @@
# These imports are tricky because they use c++, do not move them
import tqdm
from tqdm import tqdm
import os, shutil
import warnings
@@ -232,29 +232,64 @@ def test(cfg: DictConfig):
optimizer.step()
optimizer.zero_grad()
# return {'loss': loss}
# start testing
print("start testing")
graph_dit_model.eval()
test_dataloader = accelerator.prepare(datamodule.test_dataloader())
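# Test pass: one-hot encode nodes and edges, densify the batch, apply
# forward noise, and score the denoiser. The NLL from compute_val_loss is
# presumably the diffusion variational bound rather than an exact
# log-probability, which is the "log prob" this commit is after.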
for data in test_dataloader:
data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
dense_data = dense_data.mask(node_mask)
noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
pred = graph_dit_model.forward(noisy_data)
nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True)
graph_dit_model.test_y_collection.append(data.y)
print(f'test loss: {nll}')
# start sampling
samples = []
samples_left_to_generate = cfg.general.final_model_samples_to_generate
samples_left_to_save = cfg.general.final_model_samples_to_save
chains_left_to_save = cfg.general.final_model_chains_to_save
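# Generation budgets: sampling proceeds batch by batch until the
# generate / save / chain-save counters are all drained.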
for i in tqdm(
range(cfg.general.n_samples), desc="Sampling", disable=not cfg.general.enable_progress_bar
):
batch_size = cfg.train.batch_size
num_steps = cfg.model.diffusion_steps
y = torch.ones(batch_size, num_steps, 1, 1, device=accelerator.device, dtype=inference_dtype)
samples, all_ys, batch_id = [], [], 0
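# Reuse the y vectors collected during the test pass as conditioning
# targets; if more samples are requested than test examples exist,
# tile the collection until it covers the request.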
test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
num_examples = test_y_collection.size(0)
if cfg.general.final_model_samples_to_generate > num_examples:
ratio = cfg.general.final_model_samples_to_generate // num_examples
test_y_collection = test_y_collection.repeat(ratio+1, 1)
num_examples = test_y_collection.size(0)
while samples_left_to_generate > 0:
print(f'samples left to generate: {samples_left_to_generate}/'
f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
bs = 1 * cfg.train.batch_size
to_generate = min(samples_left_to_generate, bs)
to_save = min(samples_left_to_save, bs)
chains_save = min(chains_left_to_save, bs)
# batch_y = test_y_collection[batch_id : batch_id + to_generate]
batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
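# NOTE: conditioning currently uses an all-ones y of shape
# (to_generate, ydim_output) instead of the collected test labels
# (see the commented-out line above).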
# sample from the model
samples_batch = graph_dit_model.sample_batch(
batch_id=i,
batch_size=batch_size,
y=y,
keep_chain=1,
number_chain_steps=num_steps,
save_final=batch_size
)
samples.append(samples_batch)
cur_sample = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)[0]
samples = samples + cur_sample
all_ys.append(batch_y)
batch_id += to_generate
samples_left_to_save -= to_save
samples_left_to_generate -= to_generate
chains_left_to_save -= chains_save
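# All budgets drained: score the generated graphs together with their
# conditioning vectors (reset() is called before and after, presumably
# so state from earlier validation runs does not leak in).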
print(f"final Computing sampling metrics...")
graph_dit_model.sampling_metrics.reset()
graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
graph_dit_model.sampling_metrics.reset()
print(f"Done.")
# save samples
print("Samples:")
print(samples)