### folder structure and flask setup
> ls
data/  pytorch_chatbot/  save/  templates/  web.py
> ls templates/
template.html
> conda install Flask
> python web.py
 * Serving Flask app "web" (lazy loading)
 * Environment: production
   WARNING: Do not use the development server in a production environment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)

Note: web.py as listed below calls app.run(host='0.0.0.0', port=8080), so it serves on port 8080 rather than the 5000 shown in this log.

### templates/template.html
Note: web.py renders this page with render_template('translate.html'). Save it as templates/translate.html, or change that call to 'template.html'.

<html>
<title>template.html</title>
<body>
<pre>
Test page for pytorch chatbot on seq2seq dataset
<form action='translate' method='post'>
model: <input type='text' name='model' value='{{param["model"]}}' />
epoch: <input type='text' name='epoch' value='{{param["epoch"]}}' />
topn:  <input type='text' name='topn' value='{{param["topn"]}}' />
query: <input type='text' name='query' value='{{param["query"]}}'/>
<input type='submit' value='translate' />
</form>
{{param['result']}}
</pre>
</body>
</html>
##########################
# web.py
# > python web.py
#########################
"""Flask web front end for the pytorch_chatbot seq2seq model.

Routes:
    /                                    form pre-filled with ``param0``
    /translate                           model/epoch/topn/query from form or query string
    /translate/<model>/<epoch>/<topn>    same, with model/epoch/topn taken from the URL
    /list/ and /list/<model>             ``ls -l`` of the checkpoint directories
"""
import html
import itertools
import json
import random
import subprocess

import torch
from flask import Flask, request, render_template

import pytorch_chatbot.main as pcm
import pytorch_chatbot.evaluate as pce
from pytorch_chatbot.train import indexesFromSentence
from pytorch_chatbot.load import loadPrepareData
from pytorch_chatbot.model import nn, EncoderRNN, LuongAttnDecoderRNN

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
# Root that 'data/' and 'save/model/' are resolved against (plain string
# concatenation, so keep the trailing '/').
home_path = './'


def predictLoad(corpus, modelFile, n_layers=1, hidden_size=512):
    """Build the encoder/decoder for *corpus* and load their weights.

    Args:
        corpus: corpus name passed to ``loadPrepareData`` to rebuild the vocabulary.
        modelFile: checkpoint path; must hold the ``'en'`` and ``'de'`` state dicts.
        n_layers: number of RNN layers for both encoder and decoder.
        hidden_size: embedding/hidden size; must match the checkpoint.

    Returns:
        ``(encoder, decoder, voc)`` with both models in eval mode on ``device``.
    """
    print('corpus={}\nmodelFile={}'.format(corpus, modelFile))
    torch.set_grad_enabled(False)  # inference only
    voc, _pairs = loadPrepareData(corpus)
    embedding = nn.Embedding(voc.n_words, hidden_size)
    encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers)
    attn_model = 'dot'
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.n_words, n_layers)
    # Load onto the CPU first so a checkpoint saved on a GPU also works on a
    # CPU-only host; the modules are moved to `device` below.
    # SECURITY NOTE(review): torch.load unpickles the file and modelFile is
    # partly built from request input -- only trusted checkpoints must be reachable.
    checkpoint = torch.load(modelFile, map_location='cpu')
    encoder.load_state_dict(checkpoint['en'])
    decoder.load_state_dict(checkpoint['de'])
    # train mode set to false, effect only on dropout, batchNorm
    encoder.train(False)
    decoder.train(False)
    encoder = encoder.to(device)
    decoder = decoder.to(device)
    return encoder, decoder, voc


def predict(encoder, decoder, voc, question, top):
    """Return up to *top* answer strings for *question*, with '<EOS>' removed.

    With ``top == 1``, ``pce.evaluate`` is called with beam size 1 and its
    result is unpacked as ``(words, _)``. Otherwise the beam size is *top*,
    the result is iterated as ``(words, score)`` pairs, and the first *top*
    are kept.
    """
    if top == 1:
        output_words, _ = pce.evaluate(encoder, decoder, voc, question, 1)
        return [' '.join(output_words).replace('<EOS>', '')]
    output_words_list = pce.evaluate(encoder, decoder, voc, question, top)
    # max(..., 0): islice rejects negative counts; a non-positive top yields [].
    return [' '.join(output_words).replace('<EOS>', '')
            for output_words, _score in itertools.islice(output_words_list, max(top, 0))]


def filter(voc, question):
    """Keep only the whitespace-separated words of *question* found in ``voc.word2index``."""
    # NOTE: this shadows the builtin filter(); the name is kept for existing callers.
    return ' '.join(w for w in question.split() if w in voc.word2index)


# -------------------------------
def _answer_sentence(voc, en, de, top, sentence):
    """Strip, vocabulary-filter and answer one sentence.

    Returns:
        ``(source, filtered_source, answers)``.
    """
    source = sentence.rstrip()
    fil_source = filter(voc, source)
    return source, fil_source, predict(en, de, voc, fil_source, top)


def _model_file(seg_corpus_name, iteration):
    """Return the checkpoint path for *seg_corpus_name* at training *iteration*."""
    # '1-1_512' is the checkpoint sub-directory hard-coded by this script;
    # presumably it encodes layer counts and hidden size -- confirm before changing.
    return (home_path + 'save/model/' + seg_corpus_name + '/1-1_512/'
            + str(iteration) + '_backup_bidir_model.tar')


def sentence_test(voc, en, de, top, sentence):
    """Answer *sentence* and return a printable report of the source, filter and answers."""
    source, fil_source, target = _answer_sentence(voc, en, de, top, sentence)
    parts = ["\nsource: '%s'\nfilter: '%s'\n" % (source, fil_source)]
    parts.extend("\t'%s'\n" % answer for answer in target)
    parts.append('\n')
    return ''.join(parts)


def sentence_test_model(seg_corpus_name, iteration, top, sentence):
    """Load the checkpoint for (*seg_corpus_name*, *iteration*) and answer *sentence*."""
    n_layers = 1
    hidden_size = 512
    modelFile = _model_file(seg_corpus_name, iteration)
    en, de, voc = predictLoad(seg_corpus_name, modelFile, n_layers, hidden_size)
    return sentence_test(voc, en, de, top, sentence)


def file_test(voc, en, de, top, test_file_name):
    """Answer every line of *test_file_name* and print numbered results to stdout."""
    with open(test_file_name, "r") as f:
        jp_data = f.readlines()
    for i, sentence in enumerate(jp_data, start=1):
        source, fil_source, target = _answer_sentence(voc, en, de, top, sentence)
        print("%d:\nsource: '%s'\nfilter: '%s'" % (i, source, fil_source))
        for answer in target:
            print("\t'%s'" % answer)


def file_test_model(seg_corpus_name, iteration, top, test_file_name):
    """Load the checkpoint for (*seg_corpus_name*, *iteration*) and run ``file_test``."""
    n_layers = 1
    hidden_size = 512
    modelFile = _model_file(seg_corpus_name, iteration)
    en, de, voc = predictLoad(seg_corpus_name, modelFile, n_layers, hidden_size)
    file_test(voc, en, de, top, test_file_name)


def print_voc(voc):
    """Print the vocabulary size and the index-to-word table."""
    print('tw+jp voc size=%d' % (len(voc.word2index)))
    print(voc.index2word)


def list_models(seg_corpus_name=''):
    """Return ``ls -l`` output for the model root, or for one corpus's checkpoint dir.

    If ``ls`` exits non-zero (e.g. the directory does not exist), its combined
    output is returned instead of raising ``CalledProcessError``.
    """
    if seg_corpus_name == '':
        modelPath = home_path + 'save/model/'
    else:
        modelPath = home_path + 'save/model/' + seg_corpus_name + '/1-1_512'
    try:
        out_bytes = subprocess.check_output(['ls', '-l', modelPath], stderr=subprocess.STDOUT)
    except subprocess.CalledProcessError as err:
        out_bytes = err.output
    return out_bytes.decode('utf-8', errors='replace')


def load_source(seg_corpus_name):
    """Read ``data/<seg_corpus_name>.txt`` and split it into alternating lines.

    Returns:
        ``{'source': lines 0, 2, 4, ..., 'target': lines 1, 3, 5, ...}`` with the
        trailing newlines kept.
    """
    path = home_path + 'data/' + seg_corpus_name + '.txt'
    with open(path) as inp:
        data = inp.readlines()
    print(len(data), len(data[0::2]), len(data[1::2]))
    return {'source': data[0::2], 'target': data[1::2]}


# --------------------------
app = Flask(__name__)

# Values shown on the landing page; also the fallbacks for missing request values.
param0 = {
    'model': 'translation2019_train_83k',
    'epoch': 6000,
    'topn': 10,
    'query': 'what time is it?',
    'result': 'result area'
}


def _request_value(name, default):
    """Read *name* from the query string or form body, falling back to *default*."""
    # request.values merges args and form and works for every method, including
    # the HEAD requests Flask answers automatically for GET routes.
    return request.values.get(name, default)


def _is_safe_name(name):
    """Return True if *name* can be used as a single path component."""
    return (isinstance(name, str) and name not in ('', '.', '..')
            and '/' not in name and '\\' not in name)


@app.route('/')
def forms():
    """Render the translate form pre-filled with ``param0``."""
    return render_template('translate.html', param=param0)


@app.route('/translate/<model>/<int:epoch>/<int:topn>', methods=['GET', 'POST'])
def translate_long(model, epoch, topn):
    """Answer ``query`` (form or query string) using the model settings from the URL."""
    query = _request_value('query', param0['query'])
    return translate(model, epoch, topn, query)


@app.route('/translate', methods=['GET', 'POST'])
def translate_short():
    """Answer ``query`` with model, epoch and topn all read from the form or query string."""
    query = _request_value('query', param0['query'])
    model = _request_value('model', param0['model'])
    epoch = _request_value('epoch', param0['epoch'])
    topn = _request_value('topn', param0['topn'])
    return translate(model, epoch, topn, query)


def _run_model(model, epoch, topn, query):
    """Answer *query* with the given checkpoint and return the text to display."""
    if not _is_safe_name(model):
        # model is joined into filesystem paths below; refuse traversal attempts.
        return 'invalid model name: {!r}'.format(model)
    try:
        return sentence_test_model(model, epoch, topn, query)
    except Exception:
        # Show a message instead of a 500, but keep the traceback in the log.
        app.logger.exception('translate failed: model=%r epoch=%r topn=%r', model, epoch, topn)
        return 'internal error, retry a again'


def translate(model, epoch, topn, query):
    """Run the chatbot on *query* and render the result page.

    Args:
        model: corpus/model name used to locate the vocabulary and checkpoint.
        epoch: checkpoint iteration; anything ``int()`` accepts.
        topn: number of answers to show; anything ``int()`` accepts.
        query: the sentence to answer.

    Invalid epoch/topn values, unsafe model names and model failures are all
    reported in the page's result area rather than as HTTP errors.
    """
    try:
        epoch, topn = int(epoch), int(topn)
    except (TypeError, ValueError):
        target = 'invalid epoch/topn: epoch={!r} topn={!r}'.format(epoch, topn)
    else:
        target = _run_model(model, epoch, topn, query)
    result = 'query="{}"\nresult="{}"\n'.format(query, target)
    param2 = {
        'model': model,
        'epoch': epoch,
        'topn': topn,
        'query': query,
        'result': result
    }
    return render_template('translate.html', param=param2)


@app.route('/list/<model>')
def list_model(model):
    """Show the checkpoint directory listing for one corpus."""
    # Escape: the listing (or ls error text) may echo the user-supplied name.
    return '<pre>{}</pre>'.format(html.escape(list_models(model)))


# Renamed from `list`, which shadowed the builtin; endpoint='list' keeps url_for('list') working.
@app.route('/list/', endpoint='list')
def list_all_models():
    """Show the listing of the top-level model directory."""
    return '<pre>{}</pre>'.format(html.escape(list_models()))


#######################################
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
註：本程式使用 Google code-prettify（託管於 GitHub 的 JavaScript 程式碼美化工具）標示顏色，方法如下：
1. 參考〈[Blogger] 如何在 Blogger 顯示程式碼 - Google Code Prettify〉，於【Blogger 版面配置 → HTML/JavaScript 小工具】安裝如下套件：
<script src="https://cdn.jsdelivr.net/gh/google/code-prettify@master/loader/run_prettify.js"></script>
2. 編輯文章時切換到 HTML 模式，再為程式碼包上如下標籤：
<code class="prettyprint lang-html linenums"> ... </code>
<code class="prettyprint lang-python linenums"> ... </code>
2019年5月28日 星期二
Flask-based web interface deployment for a PyTorch chatbot
訂閱:
張貼留言 (Atom)
沒有留言:
張貼留言