Update xlayers
This commit is contained in:
@@ -117,16 +117,32 @@ class SuperModule(abc.ABC, nn.Module):
|
||||
else:
|
||||
return False, self._meta_info[BEST_SCORE_KEY]
|
||||
|
||||
def load_best(self):
|
||||
if BEST_DIR_KEY not in self._meta_info or BEST_SCORE_KEY not in self._meta_info:
|
||||
raise ValueError("Please call save_best at first")
|
||||
best_save_path = os.path.join(
|
||||
self._meta_info[BEST_DIR_KEY],
|
||||
"best-{:}.pth".format(self.__class__.__name__),
|
||||
)
|
||||
def load_best(self, best_save_path=None):
|
||||
if best_save_path is None:
|
||||
if (
|
||||
BEST_DIR_KEY not in self._meta_info
|
||||
or BEST_SCORE_KEY not in self._meta_info
|
||||
):
|
||||
raise ValueError("Please call save_best at first")
|
||||
best_save_name = self._meta_info.get(
|
||||
BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
|
||||
)
|
||||
best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
|
||||
state_dict = torch.load(best_save_path)
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
def has_best(self, best_name=None):
|
||||
if BEST_DIR_KEY not in self._meta_info:
|
||||
raise ValueError("Please set BEST_DIR_KEY at first")
|
||||
if best_name is None:
|
||||
best_save_name = self._meta_info.get(
|
||||
BEST_NAME_KEY, "best-{:}.pth".format(self.__class__.__name__)
|
||||
)
|
||||
else:
|
||||
best_save_name = best_name
|
||||
best_save_path = os.path.join(self._meta_info[BEST_DIR_KEY], best_save_name)
|
||||
return os.path.exists(best_save_path)
|
||||
|
||||
@property
|
||||
def abstract_search_space(self):
|
||||
raise NotImplementedError
|
||||
|
Reference in New Issue
Block a user