add naswot
graph_dit/naswot/pycls/core/checkpoint.py
98 lines
Normal file
@@ -0,0 +1,98 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Functions that handle saving and loading of checkpoints."""

import os

import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg


# Common prefix for checkpoint file names
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name
_DIR_NAME = "checkpoints"


def get_checkpoint_dir():
    """Retrieves the location for storing checkpoints."""
    return os.path.join(cfg.OUT_DIR, _DIR_NAME)


def get_checkpoint(epoch):
    """Retrieves the path to a checkpoint file."""
    name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch)
    return os.path.join(get_checkpoint_dir(), name)


def get_last_checkpoint():
    """Retrieves the most recent checkpoint (highest epoch number)."""
    checkpoint_dir = get_checkpoint_dir()
    # Checkpoint file names are in lexicographic order
    checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
    last_checkpoint_name = sorted(checkpoints)[-1]
    return os.path.join(checkpoint_dir, last_checkpoint_name)


def has_checkpoint():
    """Determines if there are checkpoints available."""
    checkpoint_dir = get_checkpoint_dir()
    if not os.path.exists(checkpoint_dir):
        return False
    return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))


def save_checkpoint(model, optimizer, epoch):
    """Saves a checkpoint."""
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    os.makedirs(get_checkpoint_dir(), exist_ok=True)
    # Omit the DDP wrapper in the multi-gpu setting
    sd = model.module.state_dict() if cfg.NUM_GPUS > 1 else model.state_dict()
    # Record the state
    if isinstance(optimizer, list):
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_w_state": optimizer[0].state_dict(),
            "optimizer_a_state": optimizer[1].state_dict(),
            "cfg": cfg.dump(),
        }
    else:
        checkpoint = {
            "epoch": epoch,
            "model_state": sd,
            "optimizer_state": optimizer.state_dict(),
            "cfg": cfg.dump(),
        }
    # Write the checkpoint
    checkpoint_file = get_checkpoint(epoch + 1)
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file


def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file."""
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    # Load the checkpoint on CPU to avoid GPU mem spike
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # Account for the DDP wrapper in the multi-gpu setting
    ms = model.module if cfg.NUM_GPUS > 1 else model
    ms.load_state_dict(checkpoint["model_state"])
    # Load the optimizer state (commonly not done when fine-tuning)
    if optimizer:
        if isinstance(optimizer, list):
            optimizer[0].load_state_dict(checkpoint["optimizer_w_state"])
            optimizer[1].load_state_dict(checkpoint["optimizer_a_state"])
        else:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
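
For orientation, here is a minimal sketch of how this checkpoint API would typically be driven from an auto-resuming training loop. It is not part of the commit; build_model, build_optimizer, train_epoch, and num_epochs are hypothetical placeholders.

import pycls.core.checkpoint as cp

# Hypothetical wiring; build_model, build_optimizer, train_epoch, and
# num_epochs are placeholders, not part of this commit.
model = build_model()
optimizer = build_optimizer(model)

start_epoch = 0
if cp.has_checkpoint():
    # Resume from the highest-numbered model_epoch_*.pyth under cfg.OUT_DIR/checkpoints
    checkpoint_epoch = cp.load_checkpoint(cp.get_last_checkpoint(), model, optimizer)
    start_epoch = checkpoint_epoch + 1  # continue after the last completed epoch

for cur_epoch in range(start_epoch, num_epochs):
    train_epoch(model, optimizer)
    # Writes model_epoch_{cur_epoch + 1:04d}.pyth and returns its path
    cp.save_checkpoint(model, optimizer, cur_epoch)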
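
The isinstance(optimizer, list) branches let NAS-style bi-level training, with one optimizer for network weights and another for architecture parameters, be checkpointed through the same interface. A hedged sketch of that case; w_optimizer, a_optimizer, and arch_parameters are illustrative names, not part of this commit.

# Hypothetical bi-level setup (e.g. DARTS-style search); the two optimizers
# are stored under optimizer_w_state and optimizer_a_state respectively.
w_optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9)
a_optimizer = torch.optim.Adam(arch_parameters, lr=3e-4)

cp.save_checkpoint(model, [w_optimizer, a_optimizer], cur_epoch)
cp.load_checkpoint(cp.get_last_checkpoint(), model, [w_optimizer, a_optimizer])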