Transformer-Translator/models/TransformerEncoderBlock.py

import torch
import torch.nn as nn

from models.MultiHeadAttention import MultiHeadAttention
from models.FFN import PositionwiseFeedForward


class TransformerEncoderBlock(nn.Module):
    '''
    A single Transformer encoder block: multi-head self-attention followed by a
    position-wise feed-forward network, each wrapped in dropout, a residual
    connection, and post-layer normalization.

    Args:
        dim (int): model (embedding) dimension of the input and output
        ffn_hidden_dim (int): hidden dimension of the position-wise FFN
        seq (int): sequence length
        device (torch.device): device the block runs on
        num_heads (int, default: 8): number of heads in the multi-head attention block
        dropout_rate (float, default: 0.1): dropout probability

    Inputs:
        x: (b, seq, dim)

    Outputs:
        x: (b, seq, dim)
    '''
    def __init__(self, dim, ffn_hidden_dim, seq, device, num_heads=8, dropout_rate=0.1):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim)
        self.attention = MultiHeadAttention(dim, num_heads=num_heads)
        self.ffn = PositionwiseFeedForward(dim, ffn_hidden_dim, dropout_rate)
        self.dropout = nn.Dropout(dropout_rate)
        # seq and device are stored for reference; they are not used inside this block
        self.seq = seq
        self.device = device
    def forward(self, x):
        # ---- Multi-head self-attention sub-layer ----
        _x = x  # keep the input for the residual connection
        x, score = self.attention(q=_x, k=_x, v=_x, mask=None)  # (b, seq, dim)
        x = self.dropout(x)
        # Add & Norm
        x = self.layer_norm1(_x + x)  # (b, seq, dim)

        # ---- Position-wise feed-forward sub-layer ----
        _x = x  # keep the attention output for the residual connection
        x = self.ffn(x)  # (b, seq, dim)
        x = self.dropout(x)
        # Add & Norm
        x = self.layer_norm2(_x + x)

        return x
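

if __name__ == "__main__":
    # Quick shape-check sketch (not part of the original module): it assumes the
    # sibling modules models/MultiHeadAttention.py and models/FFN.py exist with
    # the constructor signatures used above, and that MultiHeadAttention's
    # forward returns (output, attention_score) as the forward pass expects.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch, seq, dim, ffn_hidden_dim = 2, 16, 512, 2048

    block = TransformerEncoderBlock(dim, ffn_hidden_dim, seq, device).to(device)
    dummy = torch.randn(batch, seq, dim, device=device)

    out = block(dummy)
    print(out.shape)  # expected: torch.Size([2, 16, 512])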