feat: Transformer-based Translator
This commit is contained in:
parent
c36b0bab3f
commit
25cfa08b27
29458
data/cmn.txt
Normal file
29458
data/cmn.txt
Normal file
File diff suppressed because it is too large
Load Diff
29458
data/cmn_zh_tw.txt
Normal file
29458
data/cmn_zh_tw.txt
Normal file
File diff suppressed because it is too large
Load Diff
23
data/zh_ch_transform.py
Normal file
23
data/zh_ch_transform.py
Normal file
@ -0,0 +1,23 @@
|
||||
from opencc import OpenCC
|
||||
|
||||
def zh_ch_transform():
    '''
    Convert cmn.txt from Simplified Chinese to Traditional Chinese.

    Reads tab-separated lines ("english<TAB>simplified<TAB>attribution")
    from cmn.txt, converts the Chinese column with OpenCC ('s2t'), and
    writes "english<TAB>traditional" lines to cmn_zh_tw.txt.

    Only needs to run once.
    '''
    cc = OpenCC('s2t')
    new_lines = []
    # encoding given explicitly: the data is Chinese text and must not
    # depend on the platform's default encoding
    with open('cmn.txt', encoding='utf-8') as fp:
        for line in fp:
            english, simplified, _ = line.split('\t')
            # fixed local-name typo: was `trandition_c`
            traditional = cc.convert(simplified)
            new_lines.append("{}\t{}".format(english, traditional))

    with open("cmn_zh_tw.txt", 'w', encoding='utf-8') as fp:
        for line in new_lines:
            fp.write("{}\n".format(line))


if __name__ == '__main__':
    # guard so importing this module does not re-run the conversion
    zh_ch_transform()
|
||||
30
models/FFN.py
Normal file
30
models/FFN.py
Normal file
@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class PositionwiseFeedForward(nn.Module):
    '''
    Position-wise feed-forward sub-layer of a Transformer block.

    Two linear projections with a ReLU in between; dropout is applied
    after the activation.

    Args:
        dim (int): embedding size of the transformer block
        hidden (int): inner (expanded) size of this sub-layer
        dropout_rate (float): dropout probability applied after the activation

    Inputs:
        x: (b, seq, dim)

    Outputs:
        x: (b, seq, dim)
    '''
    def __init__(self, dim, hidden, dropout_rate=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(dim, hidden)
        self.linear2 = nn.Linear(hidden, dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.ReLU()

    def forward(self, x):
        # expand -> activate -> regularize -> project back
        hidden_state = self.relu(self.linear1(x))
        hidden_state = self.dropout(hidden_state)
        return self.linear2(hidden_state)
|
||||
59
models/MultiHeadAttention.py
Normal file
59
models/MultiHeadAttention.py
Normal file
@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class MultiHeadAttention(nn.Module):
    '''
    Multi-Head (Self) Attention block.

    Args:
        dim (int): input & output dim (must be divisible by num_heads)
        num_heads (int, default=8): number of attention heads

    Inputs:
        k: (b, seq, dim), an embedding that is projected by W_k here
           (not a key coming from elsewhere)
        q: (b, seq, dim), like k
        v: (b, seq, dim), like k
        mask (BoolTensor, default None): (b, seq_q, seq_k); True entries
            are masked out (set to -1e20 before softmax)
            (docstring fix: previously documented as (b, seq, dim))

    Outputs:
        ans: (b, seq, dim)
        score: (b, #heads, seq, seq) attention weights after softmax
    '''
    def __init__(self, dim, num_heads=8):
        super(MultiHeadAttention, self).__init__()

        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.wk = nn.Linear(dim, dim)  # b, seq, dim
        self.wq = nn.Linear(dim, dim)  # b, seq, dim
        self.wv = nn.Linear(dim, dim)  # b, seq, dim
        self.fc = nn.Linear(dim, dim)

    def forward(self, k, q, v, mask=None):
        b, seq, dim = k.shape
        k = self.wk(k)  # b, seq, dim
        q = self.wq(q)  # b, seq, dim
        v = self.wv(v)  # b, seq, dim

        # split into heads: b, #heads, seq, head_dim
        k = k.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        q = q.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)

        k = k.transpose(2, 3)  # b, #heads, head_dim, seq

        # scaled dot-product attention logits: b, #heads, seq, seq
        score = torch.matmul(q, k) / (math.sqrt(self.head_dim))
        if mask is not None:  # idiomatic None test (was `mask != None`)
            # broadcast the mask to every head
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
            score = score.masked_fill(mask, value=torch.tensor(-(1e20)))
        score = F.softmax(score, dim=-1)

        ans = torch.matmul(score, v)  # b, #heads, seq, head_dim

        # merge heads back: b, seq, dim
        ans = ans.transpose(1, 2).reshape((b, -1, dim))
        ans = self.fc(ans)  # b, seq, dim

        return ans, score
|
||||
31
models/PositionEncode.py
Normal file
31
models/PositionEncode.py
Normal file
@ -0,0 +1,31 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class PositionEncode(nn.Module):
    '''
    Sinusoidal position encoding.

    Args:
        emb_dim (int): position embedding dim
        device (torch.device)

    Inputs:
        time_seq: LongTensor (b, ) of positions

    Outputs:
        (b, emb_dim): sin on even dims, cos on odd dims
    '''
    def __init__(self, emb_dim, device):
        super(PositionEncode, self).__init__()
        # frequency exponents 0,0,1,1,2,2,... — each sin/cos pair shares one
        exponents = torch.tensor([ d // 2 for d in range(emb_dim) ]) / emb_dim
        # NOTE: plain tensor attribute (not a registered buffer), so it is
        # absent from state_dict — this matches the original design
        self.base = 1 / torch.pow(10000, exponents).to(device)  # (dim, )
        self.emb_dim = emb_dim

    def forward(self, time_seq):
        batch = time_seq.shape[0]
        # tile the per-dimension frequencies for every batch row
        freqs = self.base[:, None].reshape(1, -1).repeat(batch, 1)  # (b, dim)
        angles = freqs * time_seq[:, None]  # (b, dim)
        angles[:, 0::2] = torch.sin(angles[:, 0::2])
        angles[:, 1::2] = torch.cos(angles[:, 1::2])
        return angles
|
||||
45
models/Transformer.py
Normal file
45
models/Transformer.py
Normal file
@ -0,0 +1,45 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from models.TransformerEncoderBlock import TransformerEncoderBlock
|
||||
from models.TransformerDecoderBlock import TransformerDecoderBlock
|
||||
from models.TransformerEmbedding import TransformerEmbedding
|
||||
|
||||
class Transformer(nn.Module):
    '''
    Transformer encoder-decoder stack (without the output projection head).

    Args:
        emb_dim (int): word embedding dim (input dim)
        dim (int): dim in transformer blocks
        ffn_hidden_dim (int): dim in FFN, bigger than dim
        encoder_seq (int): encoder input's length
        decoder_seq (int): decoder input's length
        device (torch.device)
        num_heads (int, default=8)
        dropout_rate (float, default=0.1)
        num_layers (int, default=4): depth of the encoder and decoder
            stacks (generalizes the previously hard-coded 4)

    Inputs:
        encoder_input: (b, encoder_seq, emb_dim)
        decoder_input: (b, decoder_seq, emb_dim)

    Outputs:
        decoder_output: (b, decoder_seq, dim)
        score: cross-attention weights of the last decoder block
    '''
    def __init__(self, emb_dim, dim, ffn_hidden_dim, encoder_seq, decoder_seq, device, num_heads=8, dropout_rate=0.1, num_layers=4):
        super(Transformer, self).__init__()
        self.input_embedding = TransformerEmbedding(emb_dim, dim, encoder_seq, device)
        self.output_embedding = TransformerEmbedding(emb_dim, dim, decoder_seq, device)
        self.encoders = nn.ModuleList([
            TransformerEncoderBlock(dim, ffn_hidden_dim, encoder_seq, device, num_heads, dropout_rate) for _ in range(num_layers)
        ])
        self.decoders = nn.ModuleList([
            TransformerDecoderBlock(dim, ffn_hidden_dim, decoder_seq, device, num_heads, dropout_rate) for _ in range(num_layers)
        ])

    def forward(self, encoder_input, decoder_input):
        # project to model dim and add positional encodings
        encoder_input = self.input_embedding(encoder_input)
        decoder_input = self.output_embedding(decoder_input)
        for layer in self.encoders:
            encoder_input = layer(encoder_input)
        for layer in self.decoders:
            # every decoder block attends to the final encoder memory;
            # only the last block's cross-attention score survives the loop
            decoder_input, score = layer(decoder_input, encoder_input)
        return decoder_input, score
|
||||
|
||||
91
models/TransformerDecoderBlock.py
Normal file
91
models/TransformerDecoderBlock.py
Normal file
@ -0,0 +1,91 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from models.MultiHeadAttention import MultiHeadAttention
|
||||
from models.FFN import PositionwiseFeedForward
|
||||
|
||||
class TransformerDecoderBlock(nn.Module):
    '''
    Transformer decoder block (post-norm): masked self-attention, then
    cross-attention over the encoder memory, then a position-wise FFN,
    each followed by residual add + LayerNorm.

    Args:
        dim (int): input embedding's output (emb_dim -> dim)
        ffn_hidden_dim (int): hidden state which in FFN block's dim
        seq (int): sequence's length
        device (torch.device)
        num_heads (int, default: 8): for Multi-head Attention block
        dropout_rate (float, default: 0.1)
    Inputs:
        x: (b, seq, dim)
        enc: (b, encoder_seq, dim), encoder's output memory, encoder_seq is not arguments.
    Outputs:
        x: (b, seq, dim)
        score: (cross attention) (b, #heads, seq, seq)
    '''
    def __init__(self, dim, ffn_hidden_dim, seq, device, num_heads=8, dropout_rate=0.1):
        super(TransformerDecoderBlock, self).__init__()

        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim)
        self.layer_norm3 = nn.LayerNorm(dim)
        # NOTE(review): a single MultiHeadAttention instance (one weight set)
        # serves both the masked self-attention AND the cross-attention in
        # forward(); a standard Transformer decoder uses two separate modules.
        # Left unchanged here because fixing it would alter the state_dict
        # layout and break loading of existing checkpoints.
        self.attention = MultiHeadAttention(dim, num_heads=num_heads)
        self.ffn = PositionwiseFeedForward(dim, ffn_hidden_dim, dropout_rate)
        self.dropout = nn.Dropout(dropout_rate)

        self.seq = seq
        self.device = device

    def forward(self, x, enc):
        b = x.shape[0]

        # copy x, _x is the residual input
        _x = x.clone()

        # masked multi-head self-attention (causal mask from getMask)
        mask = self.getMask(b) # b, seq, seq
        x, score = self.attention(k=x, q=x, v=x, mask=mask) # b, seq, dim
        x = self.dropout(x)

        # Add && Norm
        x = _x + x
        x = self.layer_norm1(x) # b, seq, dim

        # copy x, _x is the residual input
        _x = x.clone()

        # multi-head cross-attention with encoder's memory; `score` is
        # overwritten, so only the cross-attention score is returned
        x, score = self.attention(k=enc, q=x, v=enc) # b, seq, dim
        x = self.dropout(x)

        # Add && Norm
        x = _x + x
        x = self.layer_norm2(x) # b, seq, dim

        # copy x, _x is the residual input
        _x = x.clone()

        # FFN
        x = self.ffn(x) # b, seq, dim
        x = self.dropout(x)

        # Add && Norm
        x = _x + x
        x = self.layer_norm3(x)

        return x, score

    def getMask(self, batch_size):
        '''
        Return a (b, seq, seq) causal BoolTensor mask; True (above the
        diagonal) marks positions to be masked out:
        0 1 1 1 1
        0 0 1 1 1
        0 0 0 1 1
        0 0 0 0 1
        0 0 0 0 0

        Inputs:
            batch_size (int)
        '''
        mask = torch.triu(torch.ones((self.seq, self.seq), dtype=torch.bool), diagonal=1) # (seq, seq)
        mask = mask.unsqueeze(0).repeat(batch_size, 1, 1) # (b, seq, seq)
        mask = mask.to(self.device)
        return mask
|
||||
34
models/TransformerEmbedding.py
Normal file
34
models/TransformerEmbedding.py
Normal file
@ -0,0 +1,34 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from models.PositionEncode import PositionEncode
|
||||
|
||||
class TransformerEmbedding(nn.Module):
    '''
    Project word embeddings to the transformer dim and add sinusoidal
    position encodings.

    Args:
        emb_dim (int): incoming word-embedding dim
        dim (int): transformer block dim
        seq (int): sequence length
        device (torch.device)
        dropout_rate (float, default=0.1): accepted but unused; kept for
            interface compatibility

    Inputs:
        x: (b, seq, emb_dim)

    Outputs:
        (b, seq, dim)
    '''
    def __init__(self, emb_dim, dim, seq, device, dropout_rate=0.1):
        super(TransformerEmbedding, self).__init__()
        self.position_encoding = PositionEncode(dim, device)
        self.input_embedding = nn.Linear(emb_dim, dim)

        self.dim = dim
        self.seq = seq
        self.device = device

    def forward(self, x):
        batch = x.shape[0]
        projected = self.input_embedding(x)
        return projected + self.getPositionEncoding(batch)  # b, seq, dim

    def getPositionEncoding(self, batch_size):
        '''
        Return a (b, seq, dim) position encoding, identical for every
        batch row.

        Inputs:
            batch_size (int)
        '''
        positions = torch.LongTensor(range(self.seq)).to(self.device)
        encoding = self.position_encoding(positions)  # (seq, dim)
        # add a leading batch axis and tile it batch_size times
        encoding = encoding[:, :, None].permute(2, 0, 1).repeat(batch_size, 1, 1)
        return encoding
|
||||
59
models/TransformerEncoderBlock.py
Normal file
59
models/TransformerEncoderBlock.py
Normal file
@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from models.MultiHeadAttention import MultiHeadAttention
|
||||
from models.FFN import PositionwiseFeedForward
|
||||
|
||||
class TransformerEncoderBlock(nn.Module):
    '''
    Transformer encoder block (post-norm): multi-head self-attention and a
    position-wise FFN, each followed by residual add + LayerNorm.

    Args:
        dim (int): input embedding's output (emb_dim -> dim)
        ffn_hidden_dim (int): hidden state dim inside the FFN block
        seq (int): sequence's length
        device (torch.device)
        num_heads (int, default: 8): for the Multi-head Attention block
        dropout_rate (float, default=0.1)

    Inputs:
        x: (b, seq, dim)

    Outputs:
        x: (b, seq, dim)
    '''
    def __init__(self, dim, ffn_hidden_dim, seq, device, num_heads=8, dropout_rate=0.1):
        super(TransformerEncoderBlock, self).__init__()

        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim)
        self.attention = MultiHeadAttention(dim, num_heads=num_heads)
        self.ffn = PositionwiseFeedForward(dim, ffn_hidden_dim, dropout_rate)
        self.dropout = nn.Dropout(dropout_rate)

        # stored for interface parity with the decoder block; not used here
        self.seq = seq
        self.device = device

    def forward(self, x):
        # --- self-attention sub-layer ---
        residual = x.clone()
        attended, _ = self.attention(k=residual, q=residual, v=residual, mask=None)  # b, seq, dim
        # Add && Norm
        x = self.layer_norm1(residual + self.dropout(attended))  # b, seq, dim

        # --- feed-forward sub-layer ---
        residual = x.clone()
        transformed = self.dropout(self.ffn(x))  # b, seq, dim
        # Add && Norm
        x = self.layer_norm2(residual + transformed)

        return x
|
||||
92
models/Translator.py
Normal file
92
models/Translator.py
Normal file
@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from models.Transformer import Transformer
|
||||
import os
|
||||
|
||||
class TranslatorModel(nn.Module):
    '''
    English -> Chinese translator: frozen GloVe English embeddings, a
    learned Chinese embedding, a Transformer, and an MLP head producing
    logits over the Chinese vocab.

    Args:
        num_emb_en (int): English vocab size
        num_emb_ch (int): Chinese vocab size
        emb_dim (int): word embedding dim (shared by both languages)
        en_vocab (torchtext.Vocab): used to load GloVe pretrained embeddings
        dim_in_transformer (int)
        ffn_hidden_dim (int): for the transformer's FFN module
        en_seq (int): English token length
        ch_seq (int): Chinese token length
        device (torch.device)
        num_heads (int, default: 8)
        dropout_rate (float, default: 0.1)

    Inputs:
        en_tokens: (b, en_seq) LongTensor
        ch_tokens: (b, ch_seq) LongTensor

    Outputs:
        x: (b, ch_seq, num_emb_ch) logits (no softmax; pair with CrossEntropyLoss)
        score: cross-attention weights from the transformer
    '''
    def __init__(self, num_emb_en, num_emb_ch, emb_dim, en_vocab, dim_in_transformer, ffn_hidden_dim, en_seq, ch_seq, device, num_heads=8, dropout_rate=0.1):
        super(TranslatorModel, self).__init__()

        # load glove word embedding (frozen)
        weight = self.get_glove_weight(en_vocab, num_emb_en, emb_dim)
        # from_pretrained is a classmethod — build the layer from it directly
        # instead of constructing a throwaway nn.Embedding first
        self.en_word_embedding = nn.Embedding.from_pretrained(weight, freeze=True)

        # chinese word embedding (learned)
        self.ch_word_embedding = nn.Embedding(num_emb_ch, emb_dim)

        # transformer
        self.transformer = Transformer(
            emb_dim=emb_dim,
            dim=dim_in_transformer,
            ffn_hidden_dim=ffn_hidden_dim,
            encoder_seq=en_seq,
            decoder_seq=ch_seq,
            device=device,
            # BUG FIX: these were hard-coded to 8 / 0.1, silently ignoring
            # the num_heads and dropout_rate arguments of this class
            num_heads=num_heads,
            dropout_rate=dropout_rate
        )

        # classification head over the Chinese vocab
        self.fc1 = nn.Linear(dim_in_transformer, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, num_emb_ch)
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.ReLU()

    def forward(self, en_tokens, ch_tokens):
        en_tokens = self.en_word_embedding(en_tokens)
        ch_tokens = self.ch_word_embedding(ch_tokens)
        x, score = self.transformer(en_tokens, ch_tokens)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)  # logits over the Chinese vocab
        return x, score

    def get_glove_weight(self, en_vocab, num_emb_en, emb_dim):
        '''
        Build the English embedding matrix from GloVe vectors.

        Uses the cached copy at data/word_embedding.pth when present;
        otherwise parses data/glove.6B.100d.txt, fills rows for words in
        the vocab (random init for the rest), and caches the result.

        Args:
            en_vocab (torchtext.Vocab)
            num_emb_en (int): vocab size
            emb_dim (int): word embedding dim (must match the GloVe file)
        '''
        if os.path.isfile("data/word_embedding.pth"):
            weight = torch.load("data/word_embedding.pth")
        else:
            weight = torch.randn((num_emb_en, emb_dim))
            # utf-8 explicitly, so parsing does not depend on the platform
            with open('data/glove.6B.100d.txt', encoding='utf-8') as fp:
                lines = fp.readlines()
            for line in lines:
                parts = line.split(" ")
                word = parts[0]
                emb = torch.tensor([ float(i) for i in parts[1:] ])
                if word in en_vocab:
                    weight[en_vocab[word]] = emb
            torch.save(weight, "data/word_embedding.pth")
        return weight
|
||||
|
||||
96
predict.py
Normal file
96
predict.py
Normal file
@ -0,0 +1,96 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchtext
|
||||
import numpy
|
||||
from utils.Vocab import get_vocabs
|
||||
from utils.Dataset import TranslateDataset
|
||||
from models.Transformer import Transformer
|
||||
from models.Translator import TranslatorModel
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Hyper-parameters — must match the values used by train.py so the saved
# state_dict loads into an identically-shaped model.
BATCH_SIZE = 128
EPOCH_NUM = 100          # unused at inference time
LEARNING_RATE = 1e-4     # unused at inference time
ENGLISH_SEQ = 50         # encoder (English) sequence length incl. <SOS>/<END>/<PAD>
CHINESE_SEQ = 40         # decoder (Chinese) sequence length
WORD_EMB_DIM = 100       # GloVe 100-d word embeddings
DIM_IN_TRANSFORMER = 256
FFN_HIDDEN_DIM = 512
DEVICE = torch.device('cuda')
SHOW_NUM = 5
NUM_HEADS = 8
DROPOUT_RATE = 0.5       # NOTE(review): train.py uses 0.3 — confirm which is intended
|
||||
|
||||
|
||||
def predict(en_str, model, en_vocab, ch_vocab):
    '''
    Greedily decode a Chinese translation for one English sentence.

    Args:
        en_str (str): raw English input sentence
        model (TranslatorModel)
        en_vocab (torchtext.Vocab)
        ch_vocab (torchtext.Vocab)

    Outputs:
        chinese_words (list[str]), english_words (list[str]),
        english_len (int), chinese_len (int): counts up to and including <END>,
        att: (CHINESE_SEQ, #heads, ENGLISH_SEQ) cross-attention weights
    '''
    # NOTE(review): en2tokens is not defined or imported anywhere in this
    # file — confirm where it is supposed to come from.
    en_tokens = en2tokens(en_str, en_vocab, for_model=True, en_seq=ENGLISH_SEQ)
    en_tokens = en_tokens.unsqueeze(0).to(DEVICE)

    # decoder input: <SOS> followed by <PAD>s, filled in greedily below
    ch_tokens = torch.LongTensor([ ch_vocab['<PAD>'] for _ in range(CHINESE_SEQ) ]).unsqueeze(0).to(DEVICE)
    ch_tokens[0][0] = torch.tensor(ch_vocab['<SOS>'])

    model.eval()
    att = []
    with torch.no_grad():
        for index in range(0, CHINESE_SEQ):
            # renamed local `predict` -> `pred`: the original shadowed this
            # function's own name
            pred, score = model(en_tokens, ch_tokens) # b, seq, dim
            pred = torch.argmax(pred, dim=2) # b, seq
            # keep the cross-attention row for the position just decoded
            att.append(score[0, :, index, :].unsqueeze(0))
            if index != (CHINESE_SEQ-1):
                ch_tokens[0][index+1] = pred[0][index]
    att = torch.cat(att, dim=0) # seq, #head, ENGLISH_SEQ

    english_words = en_vocab.lookup_tokens(en_tokens[0].tolist())
    chinese_words = ch_vocab.lookup_tokens(ch_tokens[0].tolist())

    # effective lengths: count tokens up to and including the first <END>
    english_len, chinese_len = 0, 0
    for word in english_words:
        english_len += 1
        if word == '<END>':
            break
    for word in chinese_words:
        chinese_len += 1
        if word == '<END>':
            break

    return chinese_words, english_words, english_len, chinese_len, att
|
||||
|
||||
if __name__ == '__main__':
    # load tokenizer & vocabs
    tokenizer = torchtext.data.utils.get_tokenizer("basic_english")
    en_vocab, ch_vocab = get_vocabs()
    print("English Vocab size: {}\nChinese Vocab size: {}\n".format(len(en_vocab), len(ch_vocab)))

    # build the model with the same hyper-parameters as the training run
    # that produced translate_model.pth
    model = TranslatorModel(
        num_emb_en=len(en_vocab),
        num_emb_ch=len(ch_vocab),
        emb_dim=WORD_EMB_DIM,
        en_vocab=en_vocab,
        dim_in_transformer=DIM_IN_TRANSFORMER,
        ffn_hidden_dim=FFN_HIDDEN_DIM,
        en_seq=ENGLISH_SEQ,
        ch_seq=CHINESE_SEQ,
        device=DEVICE,
        num_heads=NUM_HEADS,
        dropout_rate=DROPOUT_RATE
    ).to(DEVICE)

    model.load_state_dict(torch.load('translate_model.pth'))

    # interactive loop: translate one sentence, print it, then plot the
    # per-head cross-attention heatmaps
    while(1):
        s = input("English: ")
        chinese_words, english_words, english_len, chinese_len, att = predict(s, model, en_vocab, ch_vocab)
        att = att.to('cpu').numpy()

        print("Chinese: {}".format(chinese_words))

        # one subplot per attention head, laid out on a 4x2 grid
        for i in range(8):
            mapping = att[:chinese_len, i, :english_len]

            # CJK-capable font so Chinese tick labels render
            plt.rcParams['font.sans-serif'] = ['Taipei Sans TC Beta']
            plt.subplot(421+i)
            sns.heatmap(mapping, xticklabels=english_words[:english_len], yticklabels=chinese_words[:chinese_len])
        plt.show()
|
||||
127
train.py
Normal file
127
train.py
Normal file
@ -0,0 +1,127 @@
|
||||
from torchinfo import summary
|
||||
import torchtext
|
||||
from torch.utils.data import DataLoader
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy
|
||||
from utils.Vocab import get_vocabs
|
||||
from utils.Dataset import TranslateDataset
|
||||
from models.Transformer import Transformer
|
||||
from models.Translator import TranslatorModel
|
||||
|
||||
def train():
    '''
    Train the English -> Traditional Chinese translator.

    Loads vocabularies and the dataset, builds a TranslatorModel, runs the
    train / validation loops, and saves the weights with the lowest
    validation loss to translate_model.pth.
    '''
    # hyper-parameters
    BATCH_SIZE = 128
    EPOCH_NUM = 100
    LEARNING_RATE = 1e-4
    ENGLISH_SEQ = 50
    CHINESE_SEQ = 40
    WORD_EMB_DIM = 100       # GloVe 100-d
    DIM_IN_TRANSFORMER = 256
    FFN_HIDDEN_DIM = 512
    DEVICE = torch.device('cuda')
    SHOW_NUM = 5             # sample translations printed per batch
    NUM_HEADS = 8
    DROPOUT_RATE = 0.3

    tokenizer = torchtext.data.utils.get_tokenizer("basic_english")
    en_vocab, ch_vocab = get_vocabs()
    print("English Vocab size: {}\nChinese Vocab size: {}\n".format(len(en_vocab), len(ch_vocab)))

    # same seed (10) for both splits so train/val partition consistently
    train_set = TranslateDataset(10, tokenizer, en_vocab, ch_vocab, ENGLISH_SEQ, CHINESE_SEQ)
    val_set = TranslateDataset(10, tokenizer, en_vocab, ch_vocab, ENGLISH_SEQ, CHINESE_SEQ, val=True)
    print("Train set: {}\nValid set: {}\n".format(len(train_set), len(val_set)))

    train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, prefetch_factor=2)
    val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True, prefetch_factor=2)

    model = TranslatorModel(
        num_emb_en=len(en_vocab),
        num_emb_ch=len(ch_vocab),
        emb_dim=WORD_EMB_DIM,
        en_vocab=en_vocab,
        dim_in_transformer=DIM_IN_TRANSFORMER,
        ffn_hidden_dim=FFN_HIDDEN_DIM,
        en_seq=ENGLISH_SEQ,
        ch_seq=CHINESE_SEQ,
        device=DEVICE,
        num_heads=NUM_HEADS,
        dropout_rate=DROPOUT_RATE
    ).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # was `99`: a first-epoch validation loss above 99 would never be saved
    min_val_loss = float('inf')
    for epoch in range(EPOCH_NUM):
        model.train()

        loss_sum = {'train': 0, 'val': 0}
        acc_sum = {'train': 0, 'val': 0}
        count = {'train': 0, 'val': 0}
        for (en_tokens, ch_tokens), y in train_loader:
            b = len(y)
            count['train'] += b

            en_tokens, ch_tokens = en_tokens.to(DEVICE), ch_tokens.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()

            # BUG FIX: TranslatorModel.forward returns (logits, score);
            # the original assigned the whole tuple and then called .view
            # on it, which raises AttributeError
            prediction, _ = model(en_tokens, ch_tokens) # (b, seq, dim)
            prediction = prediction.view(-1, len(ch_vocab)) # (b*seq, dim)
            y = y.view(-1) # (b*seq)

            loss = criterion(prediction, y)
            loss_sum['train'] += loss.item()

            loss.backward()
            optimizer.step()

            # NOTE(review): accuracy also counts <PAD> positions, which
            # inflates the number — confirm whether masking is wanted
            acc_sum['train'] += (torch.argmax(prediction, dim=1) == y).sum()

            # print a few sample translations from this batch
            prediction = torch.argmax(prediction, dim=1).view(b, -1)
            for index, seq in enumerate(prediction):
                print(en_vocab.lookup_tokens(en_tokens[index].tolist()))
                print(ch_vocab.lookup_tokens(seq.tolist()))
                print()
                if index >= SHOW_NUM:
                    break

        # validation
        model.eval()
        with torch.no_grad():
            for (en_tokens, ch_tokens), y in val_loader:
                b = len(y)
                count['val'] += b

                en_tokens, ch_tokens = en_tokens.to(DEVICE), ch_tokens.to(DEVICE)
                y = y.to(DEVICE)

                # same tuple-unpack fix as in the training loop
                prediction, _ = model(en_tokens, ch_tokens) # (b, seq, dim)
                prediction = prediction.view(-1, len(ch_vocab)) # (b*seq, dim)
                y = y.view(-1) # (b*seq)

                loss = criterion(prediction, y)
                loss_sum['val'] += loss.item()

                acc_sum['val'] += (torch.argmax(prediction, dim=1) == y).sum()

                prediction = torch.argmax(prediction, dim=1).view(b, -1)
                for index, seq in enumerate(prediction):
                    print(en_vocab.lookup_tokens(en_tokens[index].tolist()))
                    print(ch_vocab.lookup_tokens(seq.tolist()))
                    print()
                    if index >= SHOW_NUM:
                        break
        print(count)
        print(loss_sum)
        print("EPOCH {}: with lr={}, loss: {} acc: {}".format(epoch, LEARNING_RATE, loss_sum['train']/len(train_loader), acc_sum['train']/count['train']/CHINESE_SEQ))
        print("EPOCH {}: with lr={}, loss: {} acc: {} (val)".format(epoch, LEARNING_RATE, loss_sum['val']/len(val_loader), acc_sum['val']/count['val']/CHINESE_SEQ))
        # checkpoint on best validation loss
        if((loss_sum['val']/len(val_loader)) < min_val_loss):
            min_val_loss = loss_sum['val']/len(val_loader)
            torch.save(model.state_dict(), 'translate_model.pth')
            print("MIN: {}".format(min_val_loss))
        print()


if __name__ == '__main__':
    # guard so importing train.py does not start a training run
    train()
|
||||
59
utils/Dataset.py
Normal file
59
utils/Dataset.py
Normal file
@ -0,0 +1,59 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import random
|
||||
import jieba
|
||||
|
||||
class TranslateDataset(Dataset):
    '''
    English / Chinese parallel dataset producing numeric token tensors.

    Args:
        random_seed (int): seeds the global `random` module before shuffling
        tokenizer (torchtext tokenizer): English tokenizer
        en_vocab (torchtext.Vocab): English vocabulary
        ch_vocab (torchtext.Vocab): Chinese vocabulary
        en_seq (int): English token length (padded up to this length)
        ch_seq (int): Chinese token length (padded up to this length)
        train_ratio (float, default: 0.8)
        val (bool, default: False): take the tail split instead of the head

    __getitem__ returns ((en_tokens, ch_tokens), target) where target is
    ch_tokens shifted left by one and <PAD>-extended (next-token labels).
    '''
    def __init__(self, random_seed, tokenizer, en_vocab, ch_vocab, en_seq, ch_seq, train_ratio=0.8, val=False):
        # BUG FIX: was `super(Dataset, self).__init__()`, which starts the
        # MRO search *after* Dataset and so skips the parent initializer;
        # the zero-argument form resolves correctly
        super().__init__()
        # NOTE(review): seeds the *global* random module — affects callers
        random.seed(random_seed)

        self.en_vocab = en_vocab
        self.ch_vocab = ch_vocab

        # read file (utf-8 explicitly: Chinese text)
        with open('data/cmn_zh_tw.txt', encoding='utf-8') as fp:
            lines = fp.readlines()
            length = len(lines)

        # deterministic shuffle, then train/val split
        random.shuffle(lines)
        if val:
            self.data = lines[ int(length*train_ratio): ]
        else:
            self.data = lines[ :int(length*train_ratio) ]

        # tokenize everything once, up front
        self.en_data, self.ch_data = [], []
        for line in self.data:
            en, ch = line.replace('\n', '').split('\t')

            en_tokens = en_vocab(tokenizer(en.lower()))
            en_tokens = [ en_vocab['<SOS>'] ] + en_tokens + [ en_vocab['<END>'] ]
            # NOTE(review): no truncation — a sentence longer than en_seq
            # stays longer and would break batch collation; confirm the
            # data always fits within en_seq / ch_seq
            en_tokens = en_tokens + [ en_vocab['<PAD>'] for _ in range(en_seq - len(en_tokens)) ]
            self.en_data.append(en_tokens)

            ch_tokens = ch_vocab(list(jieba.cut(ch)))
            ch_tokens = [ ch_vocab['<SOS>'] ] + ch_tokens + [ ch_vocab['<END>'] ]
            ch_tokens = ch_tokens + [ ch_vocab['<PAD>'] for _ in range(ch_seq - len(ch_tokens)) ]
            self.ch_data.append(ch_tokens)

    def __len__(self):
        return len(self.en_data)

    def __getitem__(self, index):
        # target: decoder input shifted one step left, <PAD>-extended
        target = torch.LongTensor( self.ch_data[index][1:] + [ self.ch_vocab['<PAD>'] ] )
        return (torch.LongTensor(self.en_data[index]), torch.LongTensor(self.ch_data[index])), target
|
||||
72
utils/Vocab.py
Normal file
72
utils/Vocab.py
Normal file
@ -0,0 +1,72 @@
|
||||
import torchtext
|
||||
import jieba
|
||||
import logging
|
||||
jieba.setLogLevel(logging.INFO)
|
||||
|
||||
def en_tokenizer_yeild(sentences):
    '''
    Token generator for building an English torchtext.vocab.Vocab.

    Tokenizes each sentence with torchtext's "basic_english" tokenizer and
    yields the token lists to build_vocab_from_iterator().

    (Function name keeps the original spelling "yeild" because other code
    in this module refers to it by name.)

    Args:
        sentences (list[str]): not case sensitive
    '''
    tokenizer = torchtext.data.utils.get_tokenizer("basic_english")
    for text in sentences:
        yield tokenizer(text.lower())
|
||||
|
||||
def ch_tokenizer_yeild(sentences):
    '''
    Token generator for building a Chinese torchtext.vocab.Vocab.

    Segments each sentence with jieba.cut and yields the token lists to
    build_vocab_from_iterator().

    (Function name keeps the original spelling "yeild" because other code
    in this module refers to it by name.)

    Args:
        sentences (list[str])
    '''
    for text in sentences:
        yield list(jieba.cut(text))
|
||||
|
||||
def generate_vocab(sentences, yield_f):
    '''
    Build an English or Chinese Vocab (torchtext.Vocab).

    Args:
        sentences (list[str]): English or Chinese sentences
        yield_f (function): en_tokenizer_yeild or ch_tokenizer_yeild,
            depending on which language's vocab to generate
    Outputs:
        vocab (torchtext.Vocab)
    '''
    special_tokens = ["<SOS>", "<END>", "<UNK>", "<PAD>"]
    vocab = torchtext.vocab.build_vocab_from_iterator(
        yield_f(sentences),
        min_freq=1,
        special_first=True,
        specials=special_tokens
    )
    # unknown words map to <UNK> instead of raising
    vocab.set_default_index(vocab['<UNK>'])
    return vocab
|
||||
|
||||
def get_vocabs():
    '''
    Build the English and Chinese vocabularies from data/cmn_zh_tw.txt.

    Outputs:
        en_vocab, ch_vocab (torchtext.Vocab)
    '''
    # utf-8 explicitly: the file contains Chinese text
    with open('data/cmn_zh_tw.txt', encoding='utf-8') as fp:
        sentences = fp.readlines()

    en_sentences, ch_sentences = [], []
    # plain loop — the index from the original enumerate() was never used
    for line in sentences:
        en, ch = line.replace('\n', '').split('\t')
        en_sentences.append( en.lower() )
        ch_sentences.append( ch )

    en_vocab = generate_vocab(en_sentences, en_tokenizer_yeild)
    ch_vocab = generate_vocab(ch_sentences, ch_tokenizer_yeild)
    return en_vocab, ch_vocab
|
||||
|
||||
Loading…
Reference in New Issue
Block a user