first commit
This commit is contained in:
38
NAS-Bench-201/models/set_encoder/setenc_models.py
Normal file
38
NAS-Bench-201/models/set_encoder/setenc_models.py
Normal file
@@ -0,0 +1,38 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from .setenc_modules import *
|
||||
|
||||
|
||||
class SetPool(nn.Module):
|
||||
def __init__(self, dim_input, num_outputs, dim_output,
|
||||
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
|
||||
super(SetPool, self).__init__()
|
||||
if 'sab' in mode: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
|
||||
else: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
|
||||
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
|
||||
if 'PF' in mode: # [32, 1, 501]
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
elif 'P' in mode:
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
|
||||
else: # torch.Size([32, 1, 501])
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
# "", sm, sab, sabsm
|
||||
|
||||
def forward(self, X):
|
||||
x1 = self.enc(X)
|
||||
x2 = self.dec(x1)
|
||||
return x2
|
Reference in New Issue
Block a user