Use Trainer, but it has bugs

This commit is contained in:
mhz
2024-09-19 14:11:19 +02:00
parent d36e1d1077
commit be178bc5ee
6 changed files with 750 additions and 580 deletions

View File

@@ -2,20 +2,23 @@ general:
name: 'graph_dit'
wandb: 'disabled'
gpus: 1
gpu_number: 2
gpu_number: 0
resume: null
test_only: null
sample_every_val: 2500
samples_to_generate: 512
samples_to_generate: 1000
samples_to_save: 3
chains_to_save: 1
log_every_steps: 50
number_chain_steps: 8
final_model_samples_to_generate: 100
final_model_samples_to_generate: 1000
final_model_samples_to_save: 20
final_model_chains_to_save: 1
enable_progress_bar: False
save_model: True
log_dir: '/nfs/data3/hanzhang/nasbenchDiT'
number_checkpoint_limit: 3
type: 'Trainer'
model:
type: 'discrete'
transition: 'marginal'
@@ -32,7 +35,7 @@ model:
ensure_connected: True
train:
# n_epochs: 5000
n_epochs: 500
n_epochs: 10
batch_size: 1200
lr: 0.0002
clip_grad: null
@@ -41,8 +44,11 @@ train:
seed: 0
val_check_interval: null
check_val_every_n_epoch: 1
gradient_accumulation_steps: 1
dataset:
datadir: 'data/'
task_name: 'nasbench-201'
guidance_target: 'nasbench-201'
pin_memory: False
ppo:
clip_param: 1