Transformer-Translator/models/Translator.py

import torch
import torch.nn as nn
from models.Transformer import Transformer
import os


class TranslatorModel(nn.Module):
    '''
    English-to-Chinese translation model: GloVe-initialized English embeddings,
    a learned Chinese embedding, a Transformer, and an MLP head over the Chinese vocab.

    Args:
        num_emb_en (int): English vocab size (number of words in the vocab)
        num_emb_ch (int): Chinese vocab size (number of words in the vocab)
        emb_dim (int): word embedding dim (shared by English and Chinese)
        en_vocab (torchtext.vocab.Vocab): used to load the pretrained GloVe embedding
        dim_in_transformer (int): model dimension inside the Transformer
        ffn_hidden_dim (int): hidden dim of the Transformer's FFN module
        en_seq (int): English token sequence length
        ch_seq (int): Chinese token sequence length
        device (torch.device)
        num_heads (int, default: 8)
        dropout_rate (float, default: 0.1)
    Inputs:
        en_tokens: (b, seq) LongTensor
        ch_tokens: (b, seq) LongTensor
    Outputs:
        x: (b, seq, num_emb_ch) unnormalized scores (logits) over the Chinese vocab
    '''
    def __init__(self, num_emb_en, num_emb_ch, emb_dim, en_vocab, dim_in_transformer,
                 ffn_hidden_dim, en_seq, ch_seq, device, num_heads=8, dropout_rate=0.1):
        super(TranslatorModel, self).__init__()
        # English word embedding, initialized from GloVe and frozen
        weight = self.get_glove_weight(en_vocab, num_emb_en, emb_dim)
        self.en_word_embedding = nn.Embedding.from_pretrained(weight, freeze=True)
        # Chinese word embedding, learned from scratch
        self.ch_word_embedding = nn.Embedding(num_emb_ch, emb_dim)
        # transformer (pass the configured heads/dropout through instead of hard-coding them)
        self.transformer = Transformer(
            emb_dim=emb_dim,
            dim=dim_in_transformer,
            ffn_hidden_dim=ffn_hidden_dim,
            encoder_seq=en_seq,
            decoder_seq=ch_seq,
            device=device,
            num_heads=num_heads,
            dropout_rate=dropout_rate
        )
        # MLP head projecting transformer outputs to Chinese vocab logits
        self.fc1 = nn.Linear(dim_in_transformer, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, num_emb_ch)
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.ReLU()

    def forward(self, en_tokens, ch_tokens):
        # look up embeddings: (b, seq) -> (b, seq, emb_dim)
        en_tokens = self.en_word_embedding(en_tokens)
        ch_tokens = self.ch_word_embedding(ch_tokens)
        # transformer returns decoder states and attention scores
        x, score = self.transformer(en_tokens, ch_tokens)
        # MLP head over the Chinese vocabulary
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x, score

    def get_glove_weight(self, en_vocab, num_emb_en, emb_dim):
        '''
        Build the English embedding matrix from pretrained GloVe vectors.

        Args:
            en_vocab (torchtext.vocab.Vocab)
            num_emb_en (int): number of words in the vocab
            emb_dim (int): word embedding dim
        '''
        # reuse a cached matrix if one was already built
        if os.path.isfile("data/word_embedding.pth"):
            weight = torch.load("data/word_embedding.pth")
        else:
            # words missing from GloVe keep a random initialization
            weight = torch.randn((num_emb_en, emb_dim))
            with open('data/glove.6B.100d.txt', encoding='utf-8') as fp:
                for line in fp:
                    parts = line.rstrip().split(" ")
                    word = parts[0]
                    if word in en_vocab:
                        emb = torch.tensor([float(v) for v in parts[1:]])
                        weight[en_vocab[word]] = emb
            torch.save(weight, "data/word_embedding.pth")
        return weight
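

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original training pipeline.
    # It assumes a toy en_vocab built with torchtext and that either
    # data/word_embedding.pth or data/glove.6B.100d.txt exists on disk;
    # the vocab sizes and Transformer dimensions below are illustrative placeholders.
    from torchtext.vocab import build_vocab_from_iterator

    en_vocab = build_vocab_from_iterator([["hello", "world"]], specials=["<pad>", "<unk>"])
    device = torch.device("cpu")
    model = TranslatorModel(
        num_emb_en=len(en_vocab),
        num_emb_ch=5000,            # assumed Chinese vocab size
        emb_dim=100,                # must match glove.6B.100d
        en_vocab=en_vocab,
        dim_in_transformer=256,
        ffn_hidden_dim=512,
        en_seq=32,
        ch_seq=32,
        device=device,
    )
    en_tokens = torch.randint(0, len(en_vocab), (2, 32))  # (b, seq)
    ch_tokens = torch.randint(0, 5000, (2, 32))           # (b, seq)
    logits, score = model(en_tokens, ch_tokens)
    print(logits.shape)  # expected: (2, 32, 5000)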