Re-organize GeMOSA

This commit is contained in:
D-X-Y
2021-05-27 15:44:01 +08:00
parent 8961215416
commit 6da60664f5
10 changed files with 354 additions and 350 deletions

View File

@@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule):
batch_containers.append(
self._shape_container.translate(torch.split(weights.squeeze(0), 1))
)
return batch_containers, time_embeds
return batch_containers
def forward_raw(self, timestamps, time_embeds, tembed_only=False):
raise NotImplementedError
@@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule):
def forward_candidate(self, input):
raise NotImplementedError
def easy_adapt(self, timestamp, time_embed):
with torch.no_grad():
timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, time_embed)
def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
distance = self.get_closest_meta_distance(timestamp)
if distance + self._interval * 1e-2 <= self._interval:
@@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule):
best_new_param = new_param.detach().clone()
for iepoch in range(epochs):
optimizer.zero_grad()
_, time_embed = self(timestamp.view(1), None)
time_embed = self.gen_time_embed(timestamp.view(1))
match_loss = criterion(new_param, time_embed)
[container], time_embed = self(None, new_param.view(1, -1))
[container] = self.gen_model(new_param.view(1, -1))
y_hat = base_model.forward_with_container(x, container)
meta_loss = criterion(y_hat, y)
loss = meta_loss + match_loss
loss.backward()
optimizer.step()
# print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
if meta_loss.item() < best_loss:
with torch.no_grad():
best_loss = meta_loss.item()
best_new_param = new_param.detach().clone()
with torch.no_grad():
self.replace_append_learnt(None, None)
self.append_fixed(timestamp, best_new_param)
self.easy_adapt(timestamp, best_new_param)
return True, best_loss
def extra_repr(self) -> str: