feat: PyTorch model & Telegram Bot
parent 5b89104cfa
commit 75a3545e2d
41
Bot/PTTBotModel.py
Normal file
@@ -0,0 +1,41 @@
import torch
from torch import nn


class PTTBotModel(nn.Module):
    def __init__(self, tokenizer):
        super(PTTBotModel, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=len(tokenizer), embedding_dim=300, padding_idx=2)
        self.encoder = nn.LSTM(input_size=300, hidden_size=1024, batch_first=True)
        self.decoder = nn.LSTM(input_size=300, hidden_size=1024, batch_first=True)
        self.dense = nn.Linear(in_features=1024, out_features=len(tokenizer))
        self.dropout03 = nn.Dropout(0.3)
        self.dropout05 = nn.Dropout(0.5)

    def forward(self, q, a, lenq, lena):
        # q: (BATCH, 40) padded questions; a: (BATCH, 30) padded answers
        # lenq: true question lengths; lena is unused (the answer side is never packed)
        q = self.embedding(q)
        # (BATCH, 40, 300)

        a = self.embedding(a)
        # (BATCH, 30, 300)

        q = self.dropout03(q)
        a = self.dropout03(a)

        q = nn.utils.rnn.pack_padded_sequence(q, lenq, batch_first=True, enforce_sorted=False)
        # out: PackedSequence of encoder outputs
        # h: (1, BATCH, 1024), the last hidden state
        # c: (1, BATCH, 1024), the last cell state
        out, (h, c) = self.encoder(q)

        h = self.dropout03(h)
        c = self.dropout03(c)
        out, (h, c) = self.decoder(a, (h, c))

        # (BATCH, 30, 1024)
        out = self.dense(self.dropout05(out))

        return out
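A minimal shape sanity check for the model above, written as a sketch that assumes a stand-in vocabulary of 5000 characters (FakeTokenizer is hypothetical; the real vocabulary size comes from Tokenizer):

import torch
from PTTBotModel import PTTBotModel

class FakeTokenizer:
    # stand-in so len(tokenizer) works without the pickled vocab
    def __len__(self):
        return 5000

model = PTTBotModel(FakeTokenizer())
q = torch.randint(0, 5000, (4, 40))    # padded questions, (BATCH, 40)
a = torch.randint(0, 5000, (4, 30))    # padded answers, (BATCH, 30)
lenq = torch.tensor([40, 35, 30, 25])  # true question lengths for packing
out = model(q, a, lenq, None)          # lena is ignored by forward
print(out.shape)                       # torch.Size([4, 30, 5000])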
81
Bot/bot.py
Normal file
@@ -0,0 +1,81 @@
import torch
from PTTBotModel import PTTBotModel
from tokenizer import Tokenizer
from telegram.ext import Updater, MessageHandler, Filters

TOKEN = ""


class PTTBot():
    def __init__(self):
        self.updater = Updater(token=TOKEN, use_context=True)
        self.dispatcher = self.updater.dispatcher
        self.dispatcher.add_handler(MessageHandler(Filters.text, self.response))
        self.loadModel()

    def start(self):
        self.updater.start_polling()
        print("start...")
        self.updater.idle()  # block so polling keeps running until Ctrl-C

    def loadModel(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {self.device} device")

        print("[INFO] Loading tokenizer")
        self.tokenizer = Tokenizer()

        print("[INFO] Loading model")
        self.model = PTTBotModel(self.tokenizer).to(self.device)
        self.model.load_state_dict(torch.load('./char_based_state.pt', map_location=self.device))

    def predict(self, q):
        self.model.eval()
        # Encode the question as <sos> + chars + <end>, padded to length 40
        QTokens = self.tokenizer.chars_to_tokens(q)
        QTokens = [self.tokenizer.char_index['<sos>']] + QTokens + [self.tokenizer.char_index['<end>']]
        lenq = len(QTokens)
        QTokens += [self.tokenizer.char_index['<pad>']] * (40 - len(QTokens))

        QTokens = torch.tensor(QTokens).unsqueeze(0)
        lenq = torch.tensor([lenq])
        lena = torch.tensor([])  # unused by the model's forward pass

        ans = []
        with torch.no_grad():
            QTokens = QTokens.to(self.device)

            # Greedy decoding: seed position 0 with <sos>, pad the rest
            a = [self.tokenizer.char_index['<pad>']] * 30
            a[0] = self.tokenizer.char_index['<sos>']
            a = torch.tensor(a).unsqueeze(0).to(self.device)

            for i in range(30):
                out = self.model(QTokens, a, lenq, lena)
                out = out.view(-1, out.shape[-1])        # (30, vocab) since BATCH == 1
                out_predict = torch.argmax(out, dim=1)
                if i != 29:
                    a[0][i + 1] = out_predict[i]         # feed the prediction back in

        words = self.tokenizer.tokens_to_chars(a[0])
        for word in words:
            if word not in ['<sos>', '<pad>', '<end>']:
                ans.append(word)
        return ans

    def response(self, update, context):
        chat = update.effective_chat
        q = update.message.text

        prediction = self.predict(q)

        msg = "".join(prediction)
        context.bot.send_message(
            chat_id=chat.id,
            text=msg
        )


if __name__ == '__main__':
    bot = PTTBot()
    bot.start()
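The bot above targets the pre-v20 python-telegram-bot API (Updater, Filters, use_context=True), so the pinned version matters. Since Updater rejects an empty TOKEN, the sketch below exercises predict without Telegram; it assumes char_index.pkl, index_char.pkl, and char_based_state.pt sit next to the script:

import torch
from tokenizer import Tokenizer
from PTTBotModel import PTTBotModel

bot = PTTBot.__new__(PTTBot)  # skip __init__, so no Updater / token needed
bot.device = "cpu"
bot.tokenizer = Tokenizer()
bot.model = PTTBotModel(bot.tokenizer)
bot.model.load_state_dict(torch.load('./char_based_state.pt', map_location='cpu'))

print("".join(bot.predict("安安你好")))  # hypothetical PTT-style question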
29
Bot/tokenizer.py
Normal file
@@ -0,0 +1,29 @@
import torch
import pickle


class Tokenizer():
    def __init__(self):
        # char_index: char -> id, index_char: id -> char
        with open('./char_index.pkl', 'rb') as fp:
            self.char_index = pickle.load(fp)
        with open('./index_char.pkl', 'rb') as fp:
            self.index_char = pickle.load(fp)

    def __len__(self):
        return len(self.char_index)

    def chars_to_tokens(self, chars):
        ans = []
        for ch in chars:
            if ch in self.char_index:
                ans.append(self.char_index[ch])
            else:
                ans.append(self.char_index['<unk>'])
        return ans

    def tokens_to_chars(self, tokens):
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.tolist()
        ans = []
        for token in tokens:
            ans.append(self.index_char[token])
        return ans
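A round-trip example for the tokenizer, assuming the pickled vocabulary includes the special tokens used elsewhere in this commit (<sos>, <end>, <pad>, <unk>):

tok = Tokenizer()
ids = tok.chars_to_tokens("安安 hello")  # characters missing from char_index map to <unk>
chars = tok.tokens_to_chars(ids)
print(ids)             # token ids; actual values depend on the pickled vocab
print("".join(chars))  # original text, with any OOV characters shown as <unk>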