Tuesday, May 28, 2019

Flask-based web interface deployment for PyTorch chatbot

### folder structure and Flask setup
> ls 
data/  pytorch_chatbot/  save/  templates/  web.py

> ls templates/
translate.html

> conda install Flask

> python web.py
 * Serving Flask app "web" (lazy loading)
 * Environment: production
   WARNING: Do not use the development server in a production environment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on http://0.0.0.0:8080/ (Press CTRL+C to quit)
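
Once the server is up, the form page can be fetched from another shell as a quick smoke test (any HTTP client works; curl shown here):

> curl http://127.0.0.1:8080/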


### templates/translate.html
<!DOCTYPE html>
<html>
<head><title>translate.html</title></head>
<body>
<pre>
Test page for pytorch chatbot on seq2seq dataset
<form action='/translate' method='post'>
model: <input type='text' name='model' value='{{param["model"]}}' />
epoch: <input type='text' name='epoch' value='{{param["epoch"]}}' />
topn: <input type='text' name='topn' value='{{param["topn"]}}' />
query: <input type='text' name='query' value='{{param["query"]}}'/>
<input type='submit' value='translate' />
</form>
{{param['result']}}
</pre>
</body>
</html>
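
The {{param[...]}} placeholders are Jinja2 expressions that Flask fills in from the param dict passed to render_template. A minimal standalone sketch (not part of web.py) showing the same substitution with render_template_string:

from flask import Flask, render_template_string

demo_app = Flask(__name__)
snippet = "model: {{param['model']}}, topn: {{param['topn']}}"

with demo_app.app_context():
  # prints: model: demo, topn: 3
  print(render_template_string(snippet, param={'model': 'demo', 'topn': 3}))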



##########################
# web.py
#    > python web.py
##########################
from flask import Flask, request, render_template

import torch
import random
import pytorch_chatbot.main as pcm
import pytorch_chatbot.evaluate as pce
from pytorch_chatbot.train import indexesFromSentence
from pytorch_chatbot.load import loadPrepareData
from pytorch_chatbot.model import nn, EncoderRNN, LuongAttnDecoderRNN
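# pytorch_chatbot is the chatbot repo checked out next to web.py (see the
# folder listing above); it supplies the data loading, model classes, and
# beam-search evaluation reused here.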

import subprocess
import json

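# Rebuild the encoder/decoder with the same architecture used at training
# time (the layer count and hidden size must match the checkpoint), then
# restore their weights.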
def predictLoad(corpus, modelFile, n_layers=1, hidden_size=512):
  print('corpus={}\nmodelFile={}'.format(corpus,modelFile))

  torch.set_grad_enabled(False)
  voc, pairs = loadPrepareData(corpus)
  embedding = nn.Embedding(voc.n_words, hidden_size)
  encoder = EncoderRNN(voc.n_words, hidden_size, embedding, n_layers)
  attn_model = 'dot'
  decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.n_words, n_layers)

  # map_location lets a checkpoint trained on GPU load on a CPU-only host
  checkpoint = torch.load(modelFile, map_location=device)
  encoder.load_state_dict(checkpoint['en'])
  decoder.load_state_dict(checkpoint['de'])

  # eval mode: affects only dropout and batch-norm layers
  encoder.train(False)
  decoder.train(False)

  encoder = encoder.to(device)
  decoder = decoder.to(device)

  return encoder, decoder, voc

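# pce.evaluate() returns a single hypothesis when the beam width is 1 and a
# list of (words, score) candidates otherwise; either way the <EOS> marker
# is stripped before the sentences are returned.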
def predict(encoder, decoder, voc, question, top):
  result_list = []

  if top == 1:
    beam_size = 1
    output_words, _ = pce.evaluate(encoder, decoder, voc, question, beam_size)
    answer = ' '.join(output_words)
    answer = answer.replace('<EOS>', '')
    result_list.append(answer)
  else:
    beam_size = top
    output_words_list = pce.evaluate(encoder, decoder, voc, question, beam_size)
    # keep at most the top-n scored candidates from the beam
    for count, (output_words, score) in enumerate(output_words_list, start=1):
      if count <= top:
        output_sentence = ' '.join(output_words)
        output_sentence = output_sentence.replace('<EOS>', '')
        result_list.append(output_sentence)
  
  return result_list

def filter_known_words(voc, question):
  # drop words the model's vocabulary has never seen
  words = question.split()
  result = [w for w in words if w in voc.word2index]
  return ' '.join(result)

# -------------------------------
def sentence_test(voc, en, de, top, sentence):
  source = sentence.rstrip()
  seg_source = source
  fil_source = filter_known_words(voc, seg_source)
  target = predict(en, de, voc, fil_source, top)
  result = "\nsource: '%s'\nfilter: '%s'\n" % (seg_source, fil_source)

  for answer in target:
    result = result + "\t'%s'\n" % (answer)

  result = result + '\n'
  return result

def sentence_test_model(seg_corpus_name, iteration, top, sentence):
  n_layers = 1
  hidden_size = 512

  # checkpoint path encodes the architecture: <n_layers>-<n_layers>_<hidden_size>
  modelFile = home_path + 'save/model/' + seg_corpus_name + '/1-1_512/' \
              + str(iteration) + '_backup_bidir_model.tar'

  en, de, voc = predictLoad(seg_corpus_name, modelFile, n_layers, hidden_size)

  return sentence_test(voc, en, de, top, sentence)

def file_test(voc, en, de, top, test_file_name):
  with open(test_file_name, "r") as f:
    jp_data = f.readlines()

  for i, source in enumerate(jp_data):
    source = source.rstrip()
    seg_source = source
    fil_source = filter_known_words(voc, seg_source)
    target = predict(en, de, voc, fil_source, top)
    print("%d:\nsource: '%s'\nfilter: '%s'" % (i+1, seg_source, fil_source))

    for answer in target:
      print("\t'%s'" % (answer))

def file_test_model(seg_corpus_name, iteration, top, test_file_name):
  n_layers = 1
  hidden_size = 512

  modelFile = home_path + 'save/model/' + seg_corpus_name + '/1-1_512/' \
              + str(iteration) + '_backup_bidir_model.tar'

  en, de, voc = predictLoad(seg_corpus_name, modelFile, n_layers, hidden_size)

  file_test(voc, en, de, top, test_file_name)
  
def print_voc(voc):
  print('tw+jp voc size=%d' % (len(voc.word2index)))
  print(voc.index2word)

def list_models(seg_corpus_name=''):
  if seg_corpus_name == '':
    modelPath = home_path + 'save/model/'
  else:
    modelPath = home_path + 'save/model/' + seg_corpus_name + '/1-1_512'

  # Unix-only: shells out to ls to list the saved checkpoints
  out_bytes = subprocess.check_output(['ls', '-l', modelPath],
                                      stderr=subprocess.STDOUT)
  out_text = out_bytes.decode('utf-8')
  return out_text

def load_source(seg_corpus_name):
  # the data file alternates source and target lines, so the even lines
  # (0, 2, ...) are sources and the odd lines are targets
  path = home_path + 'data/' + seg_corpus_name + '.txt'

  with open(path) as inp:
    data = inp.readlines()

  print(len(data), len(data[0::2]), len(data[1::2]))

  data = {'source': data[0::2], 'target': data[1::2]}
  return data

# --------------------------
app = Flask(__name__)

# default form values rendered into translate.html
param0 = { 'model': 'translation2019_train_83k',
           'epoch': 6000,
           'topn': 10,
           'query': 'what time is it?',
           'result': 'result area'
         }

@app.route('/')
def forms():
  return render_template('translate.html', param=param0)

@app.route('/translate/<model>/<int:epoch>/<int:topn>', methods=['GET', 'POST'])
def translate_long(model, epoch, topn):
  # request.values merges POST form data with the GET query string
  query = request.values.get('query', param0['query'])
  return translate(model, epoch, topn, query)


@app.route('/translate', methods=['GET', 'POST'])
def translate_short():
  # request.values merges POST form data with the GET query string;
  # fall back to the defaults in param0 for any missing field
  query = request.values.get('query', param0['query'])
  model = request.values.get('model', param0['model'])
  epoch = request.values.get('epoch', param0['epoch'])
  topn = request.values.get('topn', param0['topn'])

  return translate(model, epoch, topn, query)

def translate(model, epoch, topn, query):
  epoch = int(epoch)
  topn = int(topn)

  try:
    target = sentence_test_model(model, epoch, topn, query)
  except Exception:
    target = 'internal error, please retry'

  result = 'query="{}"\nresult="{}"\n'.format(query, target)

  param2 = { 'model': model,
             'epoch': epoch,
             'topn': topn,
             'query': query,
             'result': result
           }
  return render_template('translate.html', param=param2)

@app.route('/list/<model>')
def list_model(model):
  mlist = list_models(model)
  return '<pre>{}</pre>'.format(mlist)

@app.route('/list/')
def list_all():  # renamed from list() to avoid shadowing the built-in
  mlist = list_models()
  return '<pre>{}</pre>'.format(mlist)

#######################################

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
home_path = './'  # base directory holding data/ and save/ (see the folder listing)

if __name__ == '__main__':
  app.run(host='0.0.0.0',port=8080)

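With the server running, the routes can also be exercised without the form. A minimal sketch using only the standard library; the model name and epoch are just the defaults from param0, so substitute whatever checkpoint actually exists under save/model/:

import urllib.parse
import urllib.request

base = 'http://127.0.0.1:8080'

# short route: every parameter in the query string
qs = urllib.parse.urlencode({'model': 'translation2019_train_83k',
                             'epoch': 6000, 'topn': 3,
                             'query': 'what time is it?'})
print(urllib.request.urlopen(base + '/translate?' + qs).read().decode('utf-8'))

# long route: model/epoch/topn in the path, query as a parameter
url = base + '/translate/translation2019_train_83k/6000/3?query=hello'
print(urllib.request.urlopen(url).read().decode('utf-8'))

# list the saved checkpoints for one corpus
print(urllib.request.urlopen(base + '/list/translation2019_train_83k')
      .read().decode('utf-8'))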

Note: the code in this post is colorized with the google/code-prettify JavaScript tool (loaded from GitHub via jsDelivr). Setup:
   1. Following "[Blogger] how to display code in Blogger - Google Code Prettify",
      install the script below via a Blogger layout HTML/JavaScript gadget:
       <script src="https://cdn.jsdelivr.net/gh/google/code-prettify@master/loader/run_prettify.js"></script>
   2. Then, editing the post in HTML mode, wrap each code block in tags such as:
       <code class="prettyprint lang-html linenums"> ... </code>
       <code class="prettyprint lang-python linenums"> ... </code>
