make data to use 1 rather than 2 size list

This commit is contained in:
mhz
2024-09-08 23:53:56 +02:00
parent 5dccf590e7
commit 297261d666
2 changed files with 18 additions and 9 deletions

View File

@@ -76,6 +76,8 @@ class CategoricalEmbedder(nn.Module):
embeddings = embeddings + noise
return embeddings
# 相似的condition cluster起来
# size
class ClusterContinuousEmbedder(nn.Module):
def __init__(self, input_size, hidden_size, dropout_prob):
super().__init__()
@@ -108,6 +110,8 @@ class ClusterContinuousEmbedder(nn.Module):
if drop_ids is not None:
embeddings = torch.zeros((labels.shape[0], self.hidden_size), device=labels.device)
# print(labels[~drop_ids].shape)
# torch.Size([1200])
embeddings[~drop_ids] = self.mlp(labels[~drop_ids])
embeddings[drop_ids] += self.embedding_drop.weight[0]
else: