### folder structure and flask setup
> ls
data/ pytorch_chatbot/ save/ templates/ web.py
> ls templates/
template.html
> conda install Flask
> python web.py
* Serving Flask app "web" (lazy loading)
* Environment: production
WARNING: Do not use the development server in a production environment.
Use a production WSGI server instead.
* Debug mode: off
* Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
{# Test form for the chatbot: posts model/epoch/topn/query to the
   translate_short view and shows the previous result underneath.
   Expects a dict named `param` with the keys model, epoch, topn,
   query and result. #}
<html>
<title>template.html</title>
<body>
<pre>
Test page for pytorch chatbot on seq2seq dataset
{# url_for yields the /translate route regardless of the URL this page was
   served from; a relative 'translate' breaks on /translate/<model>/<epoch>/<topn>. #}
<form action='{{ url_for("translate_short") }}' method='post'>
model: <input type='text' name='model' value='{{param["model"]}}' />
epoch: <input type='text' name='epoch' value='{{param["epoch"]}}' />
topn: <input type='text' name='topn' value='{{param["topn"]}}' />
query: <input type='text' name='query' value='{{param["query"]}}'/>
<input type='submit' value='translate' />
</form>
{{param['result']}}
</pre>
</body>
</html>
-
- ##########################
- # web.py
- # > python web.py
- #########################
- from flask import Flask, request, render_template
-
- import torch
- import random
- import pytorch_chatbot.main as pcm
- import pytorch_chatbot.evaluate as pce
- from pytorch_chatbot.train import indexesFromSentence
- from pytorch_chatbot.load import loadPrepareData
- from pytorch_chatbot.model import nn, EncoderRNN, LuongAttnDecoderRNN
-
- import subprocess
- import json
-
def predictLoad(corpus, modelFile, n_layers=1, hidden_size=512):
    """Rebuild the encoder/decoder pair and load trained weights for inference.

    Args:
        corpus: Corpus identifier handed to ``loadPrepareData`` to obtain
            the vocabulary.
        modelFile: Path of the checkpoint. It must hold the encoder and
            decoder state dicts under the keys ``'en'`` and ``'de'``.
        n_layers: Layer count passed to both networks.
        hidden_size: Embedding/hidden size passed to both networks.

    Returns:
        Tuple ``(encoder, decoder, voc)``; both networks have training mode
        switched off and live on the module-level ``device``.

    Note:
        Autograd is disabled globally as a side effect.
    """
    print('corpus={}\nmodelFile={}'.format(corpus, modelFile))

    torch.set_grad_enabled(False)
    voc, _ = loadPrepareData(corpus)
    embedding = nn.Embedding(voc.n_words, hidden_size)
    encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers)
    attn_model = 'dot'
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.n_words, n_layers)

    # map_location remaps the stored tensors onto the device actually
    # available, so a checkpoint written on a GPU also loads on a CPU host.
    checkpoint = torch.load(modelFile, map_location=device)
    encoder.load_state_dict(checkpoint['en'])
    decoder.load_state_dict(checkpoint['de'])

    # train mode set to false, effect only on dropout, batchNorm
    encoder.train(False)
    decoder.train(False)

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    return encoder, decoder, voc
-
def predict(encoder, decoder, voc, question, top):
    """Return up to *top* answer strings for *question*.

    With ``top == 1`` the result of ``pce.evaluate`` is unpacked as a
    ``(words, _)`` pair; otherwise it is iterated as ``(words, score)``
    pairs and the first *top* of them are kept. Each answer is the words
    joined by single spaces with every ``'<EOS>'`` marker removed.
    """
    if top == 1:
        words, _ = pce.evaluate(encoder, decoder, voc, question, 1)
        return [' '.join(words).replace('<EOS>', '')]

    answers = []
    candidates = pce.evaluate(encoder, decoder, voc, question, top)
    for rank, (words, _score) in enumerate(candidates, start=1):
        if rank <= top:
            answers.append(' '.join(words).replace('<EOS>', ''))
    return answers
-
def filter(voc, question):
    """Keep only the tokens of *question* that are keys of ``voc.word2index``.

    The question is split on whitespace; surviving tokens keep their order
    and are re-joined with single spaces.
    """
    known_words = voc.word2index
    return ' '.join(token for token in question.split() if token in known_words)
-
- # -------------------------------
def sentence_test(voc,en,de,top,sentence):
    """Answer one sentence and return a printable report.

    The report lists the right-stripped sentence, the vocabulary-filtered
    sentence, and one tab-indented, quoted line per predicted answer,
    followed by an empty line.
    """
    raw = sentence.rstrip()
    kept = filter(voc, raw)
    answers = predict(en, de, voc, kept, top)

    header = "\nsource: '%s'\nfilter: '%s'\n" % (raw, kept)
    answer_lines = ''.join("\t'%s'\n" % (answer) for answer in answers)
    return header + answer_lines + '\n'
-
def sentence_test_model(seg_corpus_name, iteration, top, sentence):
    """Load the checkpoint saved at *iteration* for a corpus and answer *sentence*.

    The checkpoint path is built below ``home_path`` from the corpus name
    and the iteration number; the networks use 1 layer and hidden size 512.
    Returns the report string produced by ``sentence_test``.
    """
    checkpoint_name = str(iteration) + '_backup_bidir_model.tar'
    modelFile = home_path + 'save/model/' + seg_corpus_name + '/1-1_512/' + checkpoint_name

    en, de, voc = predictLoad(seg_corpus_name, modelFile, 1, 512)
    return sentence_test(voc, en, de, top, sentence)
-
def file_test(voc,en,de,top,test_file_name):
    """Answer every line of a text file and print the results.

    For each line (numbered from 1) the right-stripped source, the
    vocabulary-filtered source and the predicted answers are printed.
    """
    with open(test_file_name,"r") as handle:
        lines = handle.readlines()

    for number, line in enumerate(lines, start=1):
        raw = line.rstrip()
        kept = filter(voc, raw)
        answers = predict(en, de, voc, kept, top)
        print("%d:\nsource: '%s'\nfilter: '%s'" % (number, raw, kept))

        for answer in answers:
            print("\t'%s'" % (answer))
-
def file_test_model(seg_corpus_name, iteration, top, test_file_name):
    """Load the checkpoint saved at *iteration* for a corpus and run ``file_test``.

    The checkpoint path is built below ``home_path`` from the corpus name
    and the iteration number; the networks use 1 layer and hidden size 512.
    """
    checkpoint_name = str(iteration) + '_backup_bidir_model.tar'
    modelFile = home_path + 'save/model/' + seg_corpus_name + '/1-1_512/' + checkpoint_name

    en, de, voc = predictLoad(seg_corpus_name, modelFile, 1, 512)
    file_test(voc, en, de, top, test_file_name)
-
def print_voc(voc):
    """Print the number of entries in ``voc.word2index``, then ``voc.index2word``."""
    vocabulary_size = len(voc.word2index)
    print('tw+jp voc size=%d' % (vocabulary_size))
    print(voc.index2word)
-
def list_models(seg_corpus_name=''):
    """Return the ``ls -l`` listing of the saved-model directory as text.

    With an empty corpus name the top-level ``save/model/`` directory is
    listed, otherwise the ``1-1_512`` sub-directory of that corpus.
    stderr is merged into the captured output; ``check_output`` raises
    ``subprocess.CalledProcessError`` when ``ls`` exits non-zero.
    """
    modelPath = home_path + 'save/model/'
    if seg_corpus_name != '':
        modelPath = modelPath + seg_corpus_name + '/1-1_512'

    listing = subprocess.check_output(
        ['ls', '-l', modelPath],
        stderr=subprocess.STDOUT,
    )
    return listing.decode('utf-8')
-
def load_source(seg_corpus_name):
    """Read ``data/<corpus>.txt`` and split its lines into alternating halves.

    Lines at even positions (0, 2, ...) go to ``'source'``, lines at odd
    positions to ``'target'``. The total and per-half line counts are printed.
    """
    path = home_path + 'data/' + seg_corpus_name + '.txt'

    with open(path) as inp:
        lines = inp.readlines()

    sources = lines[0::2]
    targets = lines[1::2]
    print(len(lines), len(sources), len(targets))

    return {'source': sources, 'target': targets}
-
- # --------------------------
app = Flask(__name__)

# Values the form on the start page is pre-filled with.
param0 = {
    'model': 'translation2019_train_83k',
    'epoch': 6000,
    'topn': 10,
    'query': 'what time is it?',
    'result': 'result area',
}
-
@app.route('/')
def forms():
    """Serve the start page: the query form pre-filled with ``param0``."""
    # NOTE(review): the folder listing in this write-up shows
    # templates/template.html -- confirm the file name matches 'translate.html'.
    page = render_template('translate.html', param=param0)
    return page
-
@app.route('/translate/<model>/<int:epoch>/<int:topn>', methods=['GET', 'POST'])
def translate_long(model,epoch,topn):
    """Translate with model, epoch and topn taken from the URL path.

    The query text comes from the submitted values on POST (a missing
    ``query`` is an error there) and from the query string otherwise,
    where a missing ``query`` yields ``None``.
    """
    if request.method == 'POST':
        query = request.values['query']
    else:
        # Covers GET and the HEAD requests Flask routes to any view that
        # allows GET; an explicit GET check left `query` unbound for HEAD.
        query = request.args.get('query')

    return translate(model,epoch,topn,query)
-
-
@app.route('/translate', methods=['GET', 'POST'])
def translate_short():
    """Translate with all four parameters taken from the request.

    On POST every field (``query``, ``model``, ``epoch``, ``topn``) is
    required; for any other method they are read from the query string
    and default to ``None`` when absent.
    """
    if request.method == 'POST':
        query = request.values['query']
        model = request.values['model']
        epoch = request.values['epoch']
        topn = request.values['topn']
    else:
        # Covers GET and the HEAD requests Flask routes to any view that
        # allows GET; an explicit GET check left all four names unbound
        # for HEAD.
        query = request.args.get('query')
        model = request.args.get('model')
        epoch = request.args.get('epoch')
        topn = request.args.get('topn')

    return translate(model,epoch,topn,query)
-
def translate(model,epoch,topn,query):
    """Run one query through the selected checkpoint and render the form page.

    Args:
        model: Corpus/model name used to locate the checkpoint.
        epoch: Checkpoint iteration; converted with ``int``.
        topn: Number of answers requested; converted with ``int``.
        query: The sentence to answer.

    Returns:
        The rendered ``translate.html`` page with the form pre-filled from
        the arguments and the result text (or an error message) included.
    """
    try:
        # Converted inside the try so a missing or non-numeric value shows
        # the error message on the page instead of an unhandled exception.
        epoch = int(epoch)
        topn = int(topn)
        target = sentence_test_model(model,epoch,topn,query)
    except Exception:
        # Best effort is kept, but the cause is logged and
        # SystemExit/KeyboardInterrupt are no longer swallowed.
        app.logger.exception('translate failed: model=%r epoch=%r topn=%r',
                             model, epoch, topn)
        target = 'internal error, retry a again'

    result = 'query="{}"\nresult="{}"\n'.format(query,target)

    param2 = { 'model': model,
               'epoch': epoch,
               'topn' : topn,
               'query' : query,
               'result' : result
             }
    # NOTE(review): the folder listing in this write-up shows
    # templates/template.html -- confirm the file name matches 'translate.html'.
    return render_template('translate.html', param=param2)
-
@app.route('/list/<model>')
def list_model(model):
    """Show the checkpoint directory listing of one model as preformatted HTML."""
    from html import escape

    mlist = list_models(model)
    # The listing is plain text; escape it so '<', '>' and '&' in file
    # names are displayed literally instead of being parsed as markup.
    return '<pre>{}</pre>'.format(escape(mlist, quote=False))
-
@app.route('/list/')
def list():
    """Show the top-level saved-model directory listing as preformatted HTML."""
    from html import escape

    mlist = list_models()
    # The listing is plain text; escape it so '<', '>' and '&' in file
    # names are displayed literally instead of being parsed as markup.
    return '<pre>{}</pre>'.format(escape(mlist, quote=False))
-
- #######################################
-
# Run on the GPU when one is available, otherwise on the CPU.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Prefix of the data/ and save/ paths built by the functions in this file.
home_path = './'

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
註: 本程式使用 GitHub JavaScript code prettifier 工具標示顏色。其方法如下:
1.參考 [Blogger] 如何在 Blogger 顯示程式碼 - Google Code Prettify
於【Blogger 版面配置 HTML/JavaScript小工具】安裝如下套件
<script src="https://cdn.jsdelivr.net/gh/google/code-prettify@master/loader/run_prettify.js"></script>
2.文章編輯再以HTML模式為程式包上如下標籤。
<code class="prettyprint lang-html linenums"> ... </code>
<code class="prettyprint lang-python linenums"> ... </code>