Add save/load_best for xlayers

This commit is contained in:
D-X-Y
2021-05-13 07:57:41 +00:00
parent a2b1d0d227
commit d1836cbe52
4 changed files with 73 additions and 38 deletions

View File

@@ -2,7 +2,9 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import os
import abc
import tempfile
import warnings
from typing import Optional, Union, Callable
import torch
@@ -16,6 +18,9 @@ from .super_utils import LayerOrder, SuperRunMode
from .super_utils import TensorContainer
from .super_utils import ShapeContainer
BEST_DIR_KEY = "best_model_dir"
BEST_SCORE_KEY = "best_model_score"
class SuperModule(abc.ABC, nn.Module):
"""This class equips the nn.Module class with the ability to apply AutoDL."""
@@ -25,6 +30,7 @@ class SuperModule(abc.ABC, nn.Module):
self._super_run_type = SuperRunMode.Default
self._abstract_child = None
self._verbose = False
self._meta_info = {}
def set_super_run_type(self, super_run_type):
def _reset_super_run(m):
@@ -84,6 +90,34 @@ class SuperModule(abc.ABC, nn.Module):
total += buf.numel()
return total
def save_best(self, score):
if BEST_DIR_KEY not in self._meta_info:
tempdir = tempfile.mkdtemp("-xlayers")
self._meta_info[BEST_DIR_KEY] = tempdir
if BEST_SCORE_KEY not in self._meta_info:
self._meta_info[BEST_SCORE_KEY] = None
best_score = self._meta_info[BEST_SCORE_KEY]
if best_score is None or best_score < score:
best_save_path = os.path.join(
self._meta_info[BEST_DIR_KEY],
"best-{:}.pth".format(self.__class__.__name__),
)
self._meta_info[BEST_SCORE_KEY] = score
torch.save(self.state_dict(), best_save_path)
return True, self._meta_info[BEST_SCORE_KEY]
else:
return False, self._meta_info[BEST_SCORE_KEY]
def load_best(self):
if BEST_DIR_KEY not in self._meta_info or BEST_SCORE_KEY not in self._meta_info:
raise ValueError("Please call save_best at first")
best_save_path = os.path.join(
self._meta_info[BEST_DIR_KEY],
"best-{:}.pth".format(self.__class__.__name__),
)
state_dict = torch.load(best_save_path)
self.load_state_dict(state_dict)
@property
def abstract_search_space(self):
raise NotImplementedError