255 lines
9.8 KiB
Python
255 lines
9.8 KiB
Python
|
#!/usr/bin/env python
|
||
|
# (c) 2015--2018 NovoLanguage, author: David A. van Leeuwen
|
||
|
|
||
|
## Recognition interface for actual backend. Adapted from player.asr.debug.
|
||
|
|
||
|
import json
|
||
|
import sys
|
||
|
import wave
|
||
|
import requests
|
||
|
import websocket
|
||
|
import logging
|
||
|
import collections
|
||
|
|
||
|
import time
|
||
|
|
||
|
from .. import asr
|
||
|
|
||
|
logger = logging.getLogger(__name__)

## turn off annoying warnings
# With no argument, disable_warnings() silences urllib3's HTTPWarning family,
# e.g. InsecureRequestWarning for HTTPS requests made without verification.
requests.packages.urllib3.disable_warnings()
# Since requests 2.16 urllib3 is no longer vendored: ``requests.packages.urllib3``
# is only an alias, and urllib3 logs under its own "urllib3.*" logger names, so
# the historical name below no longer quiets anything on its own.  Set both.
logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARN)
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARN)

buffer_size = 4096  # bytes of audio per websocket binary frame when streaming audio
gm = "gm.novolanguage.com" ## dev
protocol = "https"
port = 443
apiversion = 0

# session uuid -> number of Recognizer objects currently attached to that session
sessions = collections.Counter()
|
||
|
|
||
|
def segmentation(result):
    """Turn the backend's raw word list into an ``asr.segments.Segmentation``.

    :param result: list of word dicts (the backend reply's
        ``["result"]["words"]``).  Every word must provide ``confidence``
        (with ``prob`` and ``llr``), ``label`` (with ``raw``), ``begin``,
        ``end`` and a ``phones`` list; every phone must provide
        ``confidence`` (``prob``/``llr``), ``begin`` and ``end``.
    :returns: ``asr.segments.Segmentation`` constructed from ``result``.

    The dicts are rewritten in place before conversion: ``score`` and ``llh``
    are copied out of ``confidence``, a word's ``label`` is replaced by its
    ``raw`` entry, and ``begin``/``end`` are divided by 10
    (presumably a time-unit conversion -- TODO confirm against backend docs).
    """
    for word in result:
        word_confidence = word["confidence"]
        word["score"] = word_confidence["prob"]
        word["llh"] = word_confidence["llr"]
        word["label"] = word["label"]["raw"]
        for bound in ("begin", "end"):
            word[bound] /= 10
        for phone in word["phones"]:
            phone_confidence = phone["confidence"]
            phone["score"] = phone_confidence["prob"]
            phone["llh"] = phone_confidence["llr"]
            for bound in ("begin", "end"):
                phone[bound] /= 10
    return asr.segments.Segmentation(result)
|
||
|
|
||
|
class rpcid:
    """Module-wide source of JSON-RPC request ids: 1, 2, 3, ...

    NOTE(review): the read-increment-write in ``next`` is not atomic, so
    concurrent callers on several threads could receive the same id.
    """

    id = 0  # most recently issued id; 0 means nothing has been issued yet

    @staticmethod
    def next():
        """Advance the shared counter and return the freshly issued id."""
        issued = rpcid.id + 1
        rpcid.id = issued
        return issued
|
||
|
|
||
|
class Recognizer(object):
    """Client for one recognition session on the NovoLanguage backend.

    Constructing a Recognizer logs in at the GM REST API, requests a new
    session and opens a websocket connection to it.  Grammars are installed
    with :meth:`setgrammar` / :meth:`set_alternatives_grammar`; audio is
    recognized with :meth:`recognize_wav`, :meth:`recognize_data` or
    :meth:`recognize_url`.  When the last Recognizer attached to a session is
    destroyed, the websocket is closed and, unless ``keepopen`` was given, the
    session is deleted at the API.
    """

    def __init__(self, lang="en", gm=gm, grammar_version="0.1", user=None, password=None, snodeid=None, keepopen=False):
        """Log in, request a session and connect to it.

        :param lang: language code, sent as ``"l2"`` in the session request.
        :param gm: host name of the GM API server.
        :param grammar_version: default ``version`` used by
            :meth:`set_alternatives_grammar`.
        :param user: user name for the login request.
        :param password: password for the login request.
        :param snodeid: if given, sent as ``"snodeid"`` in the session request.
        :param keepopen: if true, the backend session is not deleted when this
            object is destroyed.
        """
        self.lang = lang
        self.keepopen = keepopen
        self.grammar_version = grammar_version
        # raw JSON text of the most recent backend reply, kept for callers
        self.last_message = None
        # websocket connection; stays None if the session could not be set up
        self.conn = None
        self.api_url = "%s://%s:%d/v%d" % (protocol, gm, port, apiversion)
        # NOTE(review): TLS certificate verification is disabled for the REST
        # calls and hostname checking for the websocket; this exposes the
        # connection to man-in-the-middle attacks -- confirm it is still needed.
        self.verify = False
        self.headers = {"Content-Type": "application/json"}
        self.login_user(user, password)
        data = {"l2": lang, "local": False, "skipupload": True}
        if snodeid:
            data["snodeid"] = snodeid
        self.init_session(data)

    def login_user(self, username, password):
        """Log in at the GM API and store the auth token in ``self.headers``.

        Calls ``sys.exit(-1)`` when the request cannot be sent or when the API
        reply contains ``errors``; raises AssertionError (with the HTTP reason)
        on a non-OK HTTP status.
        """
        logger.info('obtain auth token at %s', self.api_url)
        data = {
            'username': username,
            'password': password
        }
        try:
            r = requests.post(self.api_url + '/publishers/1/login', headers=self.headers, data=json.dumps(data), verify=self.verify)
        except Exception as e:
            # Python 3 exceptions have no ``.message``; format the exception itself.
            logger.error("Cannot post request to GM API for user login: %s", e)
            sys.exit(-1)
        if not r.ok:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``;
            # the exception type seen by callers is unchanged.
            raise AssertionError(r.reason)
        result = r.json()
        if "errors" in result["response"]:
            logger.error("Error in logging in: %s", result["response"]["errors"])
            sys.exit(-1)

        user_auth_token = result['response']['user']['authentication_token']
        # The token is a credential: record that we have one, not its value.
        logger.info("User auth token obtained")

        # set auth token in header
        self.headers['Authentication-Token'] = user_auth_token

    def init_session(self, data, direct=False, use_ip=False):
        """Request a backend session and open a websocket connection to it.

        :param data: JSON body of the session request.
        :param direct: connect straight to the snode's ``ip`` and ``port``
            over plain ``ws://`` instead of through the datacentre proxy.
        :param use_ip: when going through the proxy, use its ``ip`` instead of
            its ``hostname``.

        If an HTTP request fails, the error is logged and the method returns
        without setting ``self.conn`` / ``self.session``.
        :raises RuntimeError: when the backend reports an error assigning an snode.
        """
        logger.info('Request new session: %s', data)
        r = requests.post(self.api_url + '/sessions', headers=self.headers, data=json.dumps(data), verify=self.verify)
        if not r.ok:
            logger.error("New session request failed: %s", r.text)
            return

        status_url = r.headers.get("location")
        if status_url:
            ## we got a redirect: poll the status URL once a second while the
            ## request is PENDING (no timeout -- a stuck backend blocks here)
            while True:
                logger.debug("Checking %s", status_url)
                s = requests.get(status_url, verify=self.verify)
                if not s.ok:
                    logger.error('Checking Failed: %s', s.text)
                    return
                status = s.json()
                if status['status'] != 'PENDING':
                    break
                logger.debug("Status: %s", status['status'])
                time.sleep(1)
            session = status['result'][0]  ## [1] is another status code...
            if "error" in session:
                logger.error("Error in getting a snode: %s", session["error"])
                raise RuntimeError("Error in getting a snode: %s" % session["error"])
        else:
            session = r.json()

        logger.info("Session: %r", session)
        if direct:
            snode_ip = session["snode"]["ip"]
            proxy_url = snode_ip
            snode_port = session["port"]
            ws_url = "%s://%s:%d/" % ("ws", snode_ip, snode_port)
        else:
            field = "ip" if use_ip else "hostname"
            proxy_url = session['snode']['datacentre']['proxy'][field]
            ws_url = 'wss://' + proxy_url + '/' + session['uuid']
        logger.info("Connecting to websocket: %s", ws_url)
        try:
            conn = websocket.create_connection(ws_url, sslopt={"check_hostname": self.verify})
        except Exception as e:
            # log, then re-raise the original exception with its traceback intact
            logger.error("Unable to connect to websocket: %s", e)
            raise
        logger.info("Connected.")

        self.session_id = session['id']
        self.proxy_url = proxy_url
        self.conn = conn
        self.session = session
        sessions[session["uuid"]] += 1

    def setgrammar(self, grammar): ## backend grammar object: {"data": {...}, "type": "confusion_network"}
        """Send ``grammar`` to the backend via the ``set_grammar`` RPC.

        The grammar is checked with ``asr.spraaklab.schema.validate_rpc_grammar``
        before it is sent.
        :returns: the backend's parsed JSON reply; an ``error`` in it is logged.
        """
        data = {"jsonrpc": "2.0",
                'type': 'jsonrpc',
                'method': 'set_grammar',
                'params': grammar,
                "id": rpcid.next()}
        asr.spraaklab.schema.validate_rpc_grammar(grammar)
        self.conn.send(json.dumps(data))
        result = json.loads(self.conn.recv())
        if result.get("error"):
            logger.error("Exercise validation error: %s", result)
        return result

    def set_alternatives_grammar(self, *args, **kwargs):
        """Build a grammar with ``alternatives_grammar`` and install it.

        ``version`` defaults to this recognizer's ``grammar_version``.
        """
        kwargs.setdefault("version", self.grammar_version)
        return self.setgrammar(alternatives_grammar(*args, **kwargs))

    def recognize_wav(self, wavf):
        """Recognize a mono, 16-bit, 16000 Hz, uncompressed WAV file.

        :param wavf: file name or file object accepted by ``wave.open``.
        :returns: the recognition result, or None (after logging an error)
            when the file's format is not supported.
        """
        w = wave.open(wavf, 'r')
        try:
            nchannels, sampwidth, framerate, nframes, comptype, compname = w.getparams()
            if nchannels > 1:
                logger.error("Please use .wav with only 1 channel, found %d channels in %s", nchannels, wavf)
                return None
            if sampwidth != 2:
                logger.error("Please use .wav with 2-byte PCM data, found %d bytes in %s", sampwidth, wavf)
                return None
            if framerate != 16000.0:
                logger.error("Please use .wav sampled at 16000 Hz, found %1.0f in %s", framerate, wavf)
                return None
            if comptype != 'NONE':
                logger.error("Please use .wav with uncompressed data, found %s in %s", compname, wavf)
                return None
            buf = w.readframes(nframes)
        finally:
            # close on every path, including the early format-error returns
            w.close()
        return self.recognize_data(buf)

    def recognize_data(self, buf):
        """Stream raw audio over the websocket and return its recognition.

        :param buf: audio bytes, sent unchanged in chunks of ``buffer_size``
            bytes (presumably in the format ``recognize_wav`` enforces --
            confirm for other callers).
        :raises RuntimeError: when the backend reply contains an ``error``.
        """
        start = time.time()
        nbytes_sent = 0
        for j in range(0, len(buf), buffer_size):
            # Send the slice itself: on Python 3 ``str(bytes)`` would send the
            # textual repr "b'...'" instead of the audio samples.
            audio_packet = buf[j:j + buffer_size]
            nbytes_sent += len(audio_packet)
            self.conn.send_binary(audio_packet)
        logger.debug("Sent %d bytes of audio", nbytes_sent)
        self.conn.send(json.dumps({"jsonrpc": "2.0", "method": "get_result", "id": rpcid.next()}))
        return self._receive_result(start)

    def recognize_url(self, url):
        """Ask the backend to fetch and recognize the audio at ``url``.

        :raises RuntimeError: when the backend reply contains an ``error``.
        """
        start = time.time()
        data = json.dumps({"jsonrpc": "2.0", "method": "send_audio", "id": rpcid.next(), "params": {"type": "url", "data": url, "details": ["word", "utterance"]}})
        self.conn.send(data)
        return self._receive_result(start)

    def _receive_result(self, start):
        """Wait for the recognition reply, keep it raw and convert it.

        :param start: ``time.time()`` at which the request began, for timing.
        :raises RuntimeError: when the reply contains an ``error``.
        """
        logger.info("Waiting for recognition result...")
        self.last_message = self.conn.recv()  ## keep result for the interested applications
        logger.debug("Backend reply: %s", self.last_message)
        message = json.loads(self.last_message)
        dur = time.time() - start
        logger.info("Recognition took %5.3f seconds", dur)
        if "error" in message:
            raise RuntimeError("Error from recognition backend: %r" % message.get("error"))
        return segmentation(message["result"]["words"])

    def __del__(self):
        """Release this object's claim on its backend session.

        Decrements the session's entry in ``sessions``; when it drops to zero
        the websocket is closed and, unless ``keepopen`` was requested, the
        session is deleted at the API.
        """
        # __init__ may have failed (login error, failed session request) before
        # a session was stored; there is nothing to release in that case.
        session = getattr(self, "session", None)
        if session is None:
            return
        uuid = session["uuid"]
        sessions[uuid] -= 1
        if self.conn and sessions[uuid] <= 0:
            self.conn.close()
            url = self.api_url + '/sessions/%d' % self.session_id
            if self.keepopen:
                logger.info("Keeping session open...")
            else:
                logger.info("Closing session: %s", url)
                r = requests.delete(url, headers=self.headers, verify=self.verify)
                if not r.ok:
                    # exceptions raised in __del__ are ignored anyway; log instead
                    logger.warning("Closing session %s failed: %s", url, r.reason)
|
||
|
|
||
|
def alternatives_grammar(parts, version="0.1", ret=None):
    """Make a grammar of alternatives, as array(sequence)-of-array(alternatives)-of-strings.

    :param parts: sequence of parts; each part is a list of alternative
        strings.  In version ``"1.0"`` an alternative containing several
        whitespace-separated words becomes a ``sequence`` of those words.
    :param version: grammar format, ``"0.1"`` or ``"1.0"``.
    :param ret: optional list.  For ``"0.1"`` only the presence of ``"dict"``
        matters (it sets ``return_dict``); for ``"1.0"`` the list is passed
        through as ``return_objects``.
    :returns: the grammar dict, after checking it with
        ``asr.spraaklab.schema.validate_rpc_grammar``.
    :raises ValueError: for an unsupported ``version``.
    """
    grammar = {"type": "confusion_network", "version": version}
    if version == "0.1":
        grammar["data"] = {"type": "multiple_choice", "parts": parts}
        if isinstance(ret, list) and "dict" in ret:
            grammar["return_dict"] = True
    elif version == "1.0":
        sequence = []
        for part in parts:
            alternatives = []
            for alt in part:
                # split() without an argument collapses whitespace runs and
                # drops leading/trailing blanks, so "a  b" or "a " never
                # produce empty-string words (split(" ") did).
                words = alt.split()
                if len(words) > 1:
                    alt = {"kind": "sequence", "elements": words}
                elif words:
                    alt = words[0]
                alternatives.append(alt)
            sequence.append({"kind": "alternatives", "elements": alternatives})
        grammar["data"] = {"kind": "sequence", "elements": sequence}
        if isinstance(ret, list):
            grammar["return_objects"] = ret
    else:
        raise ValueError("Unsupported version: %s" % version)
    asr.spraaklab.schema.validate_rpc_grammar(grammar)
    return grammar
|