import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from models.MultiHeadAttention import MultiHeadAttention
from models.FFN import PositionwiseFeedForward


class TransformerEncoderBlock(nn.Module):
    '''
    A single Transformer encoder block: multi-head self-attention followed by a
    position-wise feed-forward network, each wrapped in dropout, a residual
    connection, and layer normalization (post-norm).

    Args:
        dim (int): dimensionality of the input embeddings (emb_dim -> dim)
        ffn_hidden_dim (int): hidden dimension of the position-wise FFN block
        seq (int): sequence length
        device (torch.device): device the block runs on
        num_heads (int, default: 8): number of heads in the multi-head attention block
        dropout_rate (float, default: 0.1): dropout probability

    Inputs:
        x: (b, seq, dim)

    Outputs:
        x: (b, seq, dim)
    '''

    def __init__(self, dim, ffn_hidden_dim, seq, device, num_heads=8, dropout_rate=0.1):
        super(TransformerEncoderBlock, self).__init__()

        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim)
        self.attention = MultiHeadAttention(dim, num_heads=num_heads)
        self.ffn = PositionwiseFeedForward(dim, ffn_hidden_dim, dropout_rate)
        self.dropout = nn.Dropout(dropout_rate)

        self.seq = seq
        self.device = device

    def forward(self, x):
        # keep the input for the first residual connection
        _x = x.clone()

        # multi-head self-attention
        x, score = self.attention(k=_x, q=_x, v=_x, mask=None)  # (b, seq, dim)
        x = self.dropout(x)

        # Add & Norm
        x = _x + x
        x = self.layer_norm1(x)  # (b, seq, dim)

        # keep the normalized output for the second residual connection
        _x = x.clone()

        # position-wise FFN
        x = self.ffn(x)  # (b, seq, dim)
        x = self.dropout(x)

        # Add & Norm
        x = _x + x
        x = self.layer_norm2(x)

        return x
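

# Minimal usage sketch, not part of the module's required API. It assumes that
# models.MultiHeadAttention.MultiHeadAttention returns (output, attention_scores)
# and that models.FFN.PositionwiseFeedForward maps dim -> ffn_hidden_dim -> dim,
# which is what the forward pass above relies on. The shapes and hyperparameters
# below are illustrative only.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    block = TransformerEncoderBlock(
        dim=512, ffn_hidden_dim=2048, seq=128, device=device, num_heads=8, dropout_rate=0.1
    ).to(device)

    x = torch.randn(2, 128, 512, device=device)  # (batch, seq, dim)
    out = block(x)
    print(out.shape)  # expected: torch.Size([2, 128, 512])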