Add save/load_best for xlayers
This commit is contained in:
@@ -2,7 +2,9 @@
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
|
||||
import os
|
||||
import abc
|
||||
import tempfile
|
||||
import warnings
|
||||
from typing import Optional, Union, Callable
|
||||
import torch
|
||||
@@ -16,6 +18,9 @@ from .super_utils import LayerOrder, SuperRunMode
|
||||
from .super_utils import TensorContainer
|
||||
from .super_utils import ShapeContainer
|
||||
|
||||
BEST_DIR_KEY = "best_model_dir"
|
||||
BEST_SCORE_KEY = "best_model_score"
|
||||
|
||||
|
||||
class SuperModule(abc.ABC, nn.Module):
|
||||
"""This class equips the nn.Module class with the ability to apply AutoDL."""
|
||||
@@ -25,6 +30,7 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
self._super_run_type = SuperRunMode.Default
|
||||
self._abstract_child = None
|
||||
self._verbose = False
|
||||
self._meta_info = {}
|
||||
|
||||
def set_super_run_type(self, super_run_type):
|
||||
def _reset_super_run(m):
|
||||
@@ -84,6 +90,34 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
total += buf.numel()
|
||||
return total
|
||||
|
||||
def save_best(self, score):
|
||||
if BEST_DIR_KEY not in self._meta_info:
|
||||
tempdir = tempfile.mkdtemp("-xlayers")
|
||||
self._meta_info[BEST_DIR_KEY] = tempdir
|
||||
if BEST_SCORE_KEY not in self._meta_info:
|
||||
self._meta_info[BEST_SCORE_KEY] = None
|
||||
best_score = self._meta_info[BEST_SCORE_KEY]
|
||||
if best_score is None or best_score < score:
|
||||
best_save_path = os.path.join(
|
||||
self._meta_info[BEST_DIR_KEY],
|
||||
"best-{:}.pth".format(self.__class__.__name__),
|
||||
)
|
||||
self._meta_info[BEST_SCORE_KEY] = score
|
||||
torch.save(self.state_dict(), best_save_path)
|
||||
return True, self._meta_info[BEST_SCORE_KEY]
|
||||
else:
|
||||
return False, self._meta_info[BEST_SCORE_KEY]
|
||||
|
||||
def load_best(self):
|
||||
if BEST_DIR_KEY not in self._meta_info or BEST_SCORE_KEY not in self._meta_info:
|
||||
raise ValueError("Please call save_best at first")
|
||||
best_save_path = os.path.join(
|
||||
self._meta_info[BEST_DIR_KEY],
|
||||
"best-{:}.pth".format(self.__class__.__name__),
|
||||
)
|
||||
state_dict = torch.load(best_save_path)
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
raise NotImplementedError
|
||||
|
Reference in New Issue
Block a user