Update xlayers

This commit is contained in:
D-X-Y
2021-05-22 23:04:24 +08:00
parent 5b09f059fd
commit 8109ed166a
6 changed files with 104 additions and 33 deletions

View File

@@ -117,16 +117,32 @@ class SuperModule(abc.ABC, nn.Module):
else:
return False, self._meta_info[BEST_SCORE_KEY]
def load_best(self):
if BEST_DIR_KEY not in self._meta_info or BEST_SCORE_KEY not in self._meta_info:
raise ValueError("Please call save_best at first")
best_save_path = os.path.join(
self._meta_info[BEST_DIR_KEY],
"best-{:}.pth".format(self.__class__.__name__),
)
def load_best(self, best_save_path=None):
if best_save_path is None:
if (
BEST_DIR_KEY not in self._meta_info
or BEST_SCORE_KEY not in self._meta_info
):
raise ValueError("Please call save_best at first")
best_save_name = self._meta_info.get(
BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
)
best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
state_dict = torch.load(best_save_path)
self.load_state_dict(state_dict)
def has_best(self, best_name=None):
if BEST_DIR_KEY not in self._meta_info:
raise ValueError("Please set BEST_DIR_KEY at first")
if best_name is None:
best_save_name = self._meta_info.get(
BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
)
else:
best_save_name = best_name
best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
return os.path.exists(best_save_path)
@property
def abstract_search_space(self):
raise NotImplementedError