Transformer-Translator/models/FFN.py

import torch
import torch.nn as nn
class PositionwiseFeedForward(nn.Module):
    '''
    Position-wise Feed-Forward Network in a Transformer block:
    FFN(x) = Linear2(Dropout(ReLU(Linear1(x)))).

    Args:
        dim (int): embedding dimension of the Transformer block
        hidden (int): inner (hidden) dimension of the feed-forward layer
        dropout_rate (float): dropout rate applied after the activation

    Inputs:
        x: (batch, seq, dim)

    Outputs:
        x: (batch, seq, dim)
    '''
    def __init__(self, dim, hidden, dropout_rate=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(dim, hidden)    # expand: dim -> hidden
        self.linear2 = nn.Linear(hidden, dim)    # project back: hidden -> dim
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Applied independently at every position: (b, seq, dim) -> (b, seq, dim)
        x = self.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        return x
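
A minimal usage sketch of this module. The import path assumes the repository root is on PYTHONPATH; the tensor shapes are illustrative (512/2048 are the d_model/d_ff sizes from the original Transformer paper, not necessarily this repo's config):

import torch
from models.FFN import PositionwiseFeedForward

# Illustrative: batch of 2 sequences of length 10 with embedding dim 512.
ffn = PositionwiseFeedForward(dim=512, hidden=2048, dropout_rate=0.1)
x = torch.randn(2, 10, 512)
out = ffn(x)
print(out.shape)  # torch.Size([2, 10, 512]) -- input shape is preserved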