set batch_y to 1 and want to test 15625
This commit is contained in:
@@ -356,7 +356,8 @@ class Graph_DiT(pl.LightningModule):
|
||||
to_generate = min(samples_left_to_generate, bs)
|
||||
to_save = min(samples_left_to_save, bs)
|
||||
chains_save = min(chains_left_to_save, bs)
|
||||
batch_y = test_y_collection[batch_id : batch_id + to_generate]
|
||||
# batch_y = test_y_collection[batch_id : batch_id + to_generate]
|
||||
batch_y = torch.ones(to_generate, self.ydim_output, device=self.device)
|
||||
|
||||
cur_sample = self.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
|
||||
keep_chain=chains_save, number_chain_steps=self.number_chain_steps)
|
||||
|
Reference in New Issue
Block a user