31 lines
916 B
Python
31 lines
916 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import math
|
|
|
|
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise feed-forward network used inside a Transformer block.

    The same two-layer perceptron (Linear -> ReLU -> Dropout -> Linear) is
    applied to every position; both linear layers act on the last dimension.

    Args:
        dim (int): embedding size of the transformer block (input and
            output feature size of this module)
        hidden (int): size of the intermediate (hidden) projection
        dropout_rate (float): drop probability of the dropout layer that
            follows the activation

    Inputs:
        x: (b, seq, dim)

    Outputs:
        x: (b, seq, dim)
    """

    def __init__(self, dim, hidden, dropout_rate=0.1):
        super().__init__()
        # Expansion and contraction projections. They are created in this
        # order so parameter initialisation draws from the RNG in the same
        # sequence as before.
        self.linear1 = nn.Linear(in_features=dim, out_features=hidden)
        self.linear2 = nn.Linear(in_features=hidden, out_features=dim)
        # Parameter-free layers applied between the two projections.
        self.dropout = nn.Dropout(p=dropout_rate)
        self.relu = nn.ReLU()

    def forward(self, x):
        # dim -> hidden, followed by the non-linearity.
        expanded = self.linear1(x)
        activated = self.relu(expanded)
        # Dropout is applied to the hidden activations only (it is a no-op
        # when the module is in eval mode).
        regularised = self.dropout(activated)
        # hidden -> dim, restoring the input feature size.
        return self.linear2(regularised)
|