first commit
This commit is contained in:
38
NAS-Bench-201/models/set_encoder/setenc_models.py
Normal file
38
NAS-Bench-201/models/set_encoder/setenc_models.py
Normal file
@@ -0,0 +1,38 @@
|
||||
###########################################################################################
|
||||
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||||
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||||
###########################################################################################
|
||||
from .setenc_modules import *
|
||||
|
||||
|
||||
class SetPool(nn.Module):
|
||||
def __init__(self, dim_input, num_outputs, dim_output,
|
||||
num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
|
||||
super(SetPool, self).__init__()
|
||||
if 'sab' in mode: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
SAB(dim_input, dim_hidden, num_heads, ln=ln), # SAB?
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
|
||||
else: # [32, 400, 128]
|
||||
self.enc = nn.Sequential(
|
||||
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln), # SAB?
|
||||
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
|
||||
if 'PF' in mode: # [32, 1, 501]
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
elif 'P' in mode:
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln))
|
||||
else: # torch.Size([32, 1, 501])
|
||||
self.dec = nn.Sequential(
|
||||
PMA(dim_hidden, num_heads, num_outputs, ln=ln), # 32 1 128
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
|
||||
nn.Linear(dim_hidden, dim_output))
|
||||
# "", sm, sab, sabsm
|
||||
|
||||
def forward(self, X):
|
||||
x1 = self.enc(X)
|
||||
x2 = self.dec(x1)
|
||||
return x2
|
67
NAS-Bench-201/models/set_encoder/setenc_modules.py
Normal file
67
NAS-Bench-201/models/set_encoder/setenc_modules.py
Normal file
@@ -0,0 +1,67 @@
|
||||
#####################################################################################
|
||||
# Copyright (c) Juho Lee SetTransformer, ICML 2019 [GitHub set_transformer]
|
||||
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
|
||||
######################################################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class MAB(nn.Module):
|
||||
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
|
||||
super(MAB, self).__init__()
|
||||
self.dim_V = dim_V
|
||||
self.num_heads = num_heads
|
||||
self.fc_q = nn.Linear(dim_Q, dim_V)
|
||||
self.fc_k = nn.Linear(dim_K, dim_V)
|
||||
self.fc_v = nn.Linear(dim_K, dim_V)
|
||||
if ln:
|
||||
self.ln0 = nn.LayerNorm(dim_V)
|
||||
self.ln1 = nn.LayerNorm(dim_V)
|
||||
self.fc_o = nn.Linear(dim_V, dim_V)
|
||||
|
||||
def forward(self, Q, K):
|
||||
Q = self.fc_q(Q)
|
||||
K, V = self.fc_k(K), self.fc_v(K)
|
||||
|
||||
dim_split = self.dim_V // self.num_heads
|
||||
Q_ = torch.cat(Q.split(dim_split, 2), 0)
|
||||
K_ = torch.cat(K.split(dim_split, 2), 0)
|
||||
V_ = torch.cat(V.split(dim_split, 2), 0)
|
||||
|
||||
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
|
||||
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
|
||||
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
|
||||
O = O + F.relu(self.fc_o(O))
|
||||
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
|
||||
return O
|
||||
|
||||
class SAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, ln=False):
|
||||
super(SAB, self).__init__()
|
||||
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(X, X)
|
||||
|
||||
class ISAB(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
|
||||
super(ISAB, self).__init__()
|
||||
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
|
||||
nn.init.xavier_uniform_(self.I)
|
||||
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
|
||||
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
|
||||
return self.mab1(X, H)
|
||||
|
||||
class PMA(nn.Module):
|
||||
def __init__(self, dim, num_heads, num_seeds, ln=False):
|
||||
super(PMA, self).__init__()
|
||||
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
|
||||
nn.init.xavier_uniform_(self.S)
|
||||
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
|
||||
|
||||
def forward(self, X):
|
||||
return self.mab(self.S.repeat(X.size(0), 1, 1), X)
|
Reference in New Issue
Block a user