Re-organize GeMOSA
@@ -198,7 +198,7 @@ class MetaModelV1(super_core.SuperModule):
             batch_containers.append(
                 self._shape_container.translate(torch.split(weights.squeeze(0), 1))
             )
-        return batch_containers, time_embeds
+        return batch_containers
 
     def forward_raw(self, timestamps, time_embeds, tembed_only=False):
         raise NotImplementedError
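The change above drops `time_embeds` from this return, so the batch of translated weight containers is returned on its own; the time embeddings are produced by a separate call (`gen_time_embed`, used in the later hunks). A rough sketch of the resulting two-step call pattern; tensor shapes and the names `meta_model`, `base_model`, `timestamp`, and `x` are assumptions, not from the commit:

    # Sketch: containers and time embeddings now come from two separate calls.
    time_embed = meta_model.gen_time_embed(timestamp.view(1))    # time -> embedding
    [container] = meta_model.gen_model(time_embed.view(1, -1))   # embedding -> weight container
    y_hat = base_model.forward_with_container(x, container)      # run the base model with it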
@@ -206,6 +206,12 @@ class MetaModelV1(super_core.SuperModule):
     def forward_candidate(self, input):
         raise NotImplementedError
 
+    def easy_adapt(self, timestamp, time_embed):
+        with torch.no_grad():
+            timestamp = torch.Tensor([timestamp]).to(self._meta_timestamps.device)
+            self.replace_append_learnt(None, None)
+            self.append_fixed(timestamp, time_embed)
+
     def adapt(self, base_model, criterion, timestamp, x, y, lr, epochs, init_info):
         distance = self.get_closest_meta_distance(timestamp)
         if distance + self._interval * 1e-2 <= self._interval:
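The newly added `easy_adapt` registers an already-known time embedding without any optimization: it clears the learnable appended slot and stores the (timestamp, embedding) pair as fixed. A minimal usage sketch, assuming a scalar `timestamp`, a `device`, and a `meta_model` instance of this class; `gen_time_embed` is taken from the next hunk and the embedding's exact shape is an assumption:

    import torch

    # Sketch: compute an embedding for the new timestamp, then pin it as fixed.
    with torch.no_grad():
        time_embed = meta_model.gen_time_embed(
            torch.Tensor([timestamp]).to(device)
        )
    meta_model.easy_adapt(timestamp, time_embed)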
@@ -230,23 +236,20 @@ class MetaModelV1(super_core.SuperModule):
                 best_new_param = new_param.detach().clone()
             for iepoch in range(epochs):
                 optimizer.zero_grad()
-                _, time_embed = self(timestamp.view(1), None)
+                time_embed = self.gen_time_embed(timestamp.view(1))
                 match_loss = criterion(new_param, time_embed)
 
-                [container], time_embed = self(None, new_param.view(1, -1))
+                [container] = self.gen_model(new_param.view(1, -1))
                 y_hat = base_model.forward_with_container(x, container)
                 meta_loss = criterion(y_hat, y)
                 loss = meta_loss + match_loss
                 loss.backward()
                 optimizer.step()
                 # print("{:03d}/{:03d} : loss : {:.4f} = {:.4f} + {:.4f}".format(iepoch, epochs, loss.item(), meta_loss.item(), match_loss.item()))
                 if meta_loss.item() < best_loss:
                     with torch.no_grad():
                         best_loss = meta_loss.item()
                         best_new_param = new_param.detach().clone()
-        with torch.no_grad():
-            self.replace_append_learnt(None, None)
-            self.append_fixed(timestamp, best_new_param)
+        self.easy_adapt(timestamp, best_new_param)
         return True, best_loss
 
     def extra_repr(self) -> str:
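In the last hunk the two `self(...)` forward calls are replaced by the explicit generators `gen_time_embed` and `gen_model`, and the trailing no-grad bookkeeping collapses into the new `easy_adapt`. Condensed, the adaptation step now reads roughly as below; the best-loss tracking and the setup of `new_param`, `optimizer`, and `criterion` from the unchanged context are omitted, so this is not a verbatim copy of the method:

    # Condensed sketch of the refactored inner loop of adapt().
    for iepoch in range(epochs):
        optimizer.zero_grad()
        time_embed = self.gen_time_embed(timestamp.view(1))    # embedding predicted from time
        match_loss = criterion(new_param, time_embed)          # keep new_param close to it
        [container] = self.gen_model(new_param.view(1, -1))    # weights from the learnable embedding
        y_hat = base_model.forward_with_container(x, container)
        loss = criterion(y_hat, y) + match_loss
        loss.backward()
        optimizer.step()
    self.easy_adapt(timestamp, new_param.detach())             # store the adapted embedding as fixed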