You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

255 lines
9.8 KiB

#!/usr/bin/env python
# (c) 2015--2018 NovoLanguage, author: David A. van Leeuwen
## Recognition interface for actual backend. Adapted from player.asr.debug.
import json
import sys
import wave
import requests
import websocket
import logging
import collections
import time
from .. import asr
logger = logging.getLogger(__name__)

## Turn off annoying warnings: Recognizer talks to the API with certificate
## verification switched off (verify=False), and urllib3 warns about every
## unverified HTTPS request.
requests.packages.urllib3.disable_warnings()
## Newer requests releases alias ``requests.packages.urllib3`` to the standalone
## ``urllib3`` package, whose loggers are named ``urllib3.*``; older releases
## vendored it under the ``requests.packages`` name.  Quiet both names so the
## level actually applies whichever layout is installed.
logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARN)
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARN)

buffer_size = 4096  ## bytes per binary websocket frame sent by Recognizer.recognize_data
gm = "gm.novolanguage.com" ## dev
protocol = "https"
port = 443
apiversion = 0  ## API base URL is <protocol>://<gm>:<port>/v<apiversion>
## Number of live Recognizer objects per backend session uuid; a Recognizer only
## closes the shared websocket/session when this count drops to zero.
sessions = collections.Counter()
def segmentation(result):
    """Convert a raw backend word list into an ``asr.segments.Segmentation``.

    Every word dict in *result* (and every phone dict in its ``"phones"`` list)
    is rewritten in place before being handed to ``Segmentation``:

    * ``score`` <- ``confidence["prob"]`` and ``llh`` <- ``confidence["llr"]``
    * ``begin`` / ``end`` are divided by 10 (presumably a time-unit
      conversion -- TODO confirm the backend's units)
    * for words only, ``label`` is replaced by ``label["raw"]``
    """
    time_scale = 10
    for word in result:
        word_conf = word["confidence"]
        word["score"] = word_conf["prob"]
        word["llh"] = word_conf["llr"]
        word["label"] = word["label"]["raw"]
        word["begin"] /= time_scale
        word["end"] /= time_scale
        for phone in word["phones"]:
            phone_conf = phone["confidence"]
            phone["score"] = phone_conf["prob"]
            phone["llh"] = phone_conf["llr"]
            phone["begin"] /= time_scale
            phone["end"] /= time_scale
    return asr.segments.Segmentation(result)
class rpcid:
    """Process-wide counter that hands out JSON-RPC request ids."""

    id = 0  ## last id handed out; 0 means none issued yet

    @staticmethod
    def next():
        """Advance the shared counter and return the new id (first call gives 1)."""
        issued = rpcid.id + 1
        rpcid.id = issued
        return issued
class Recognizer(object):
    """Client for one recognition session on the NovoLanguage backend.

    Construction logs in to the GM REST API, requests a session and opens a
    websocket to it.  Grammars and audio are then exchanged over that websocket
    as JSON-RPC messages.  Several Recognizer objects may share one backend
    session (counted per uuid in the module-level ``sessions`` counter).
    """

    def __init__(self, lang="en", gm=gm, grammar_version="0.1", user=None, password=None, snodeid=None, keepopen=False):
        """Log in and open a recognition session.

        :param lang: language code, sent as ``"l2"`` in the session request.
        :param gm: host name of the GM API server.
        :param grammar_version: default ``version`` used by
            :meth:`set_alternatives_grammar`.
        :param user: GM user name.
        :param password: GM password.
        :param snodeid: optional snode id to request, sent as ``"snodeid"``.
        :param keepopen: if true, the backend session is not deleted when the
            last Recognizer using it is garbage collected.
        """
        self.lang = lang
        self.keepopen = keepopen
        self.api_url = "%s://%s:%d/v%d" % (protocol, gm, port, apiversion)
        self.verify = False  ## TLS certificates are not verified for API/websocket calls
        self.headers = {"Content-Type": "application/json"}
        self.login_user(user, password)
        data = {"l2": lang, "local": False, "skipupload": True}
        if snodeid:
            data["snodeid"] = snodeid
        self.conn = None
        self.init_session(data)
        self.grammar_version = grammar_version
        self.last_message = None

    def login_user(self, username, password):
        """Obtain an authentication token and store it in ``self.headers``.

        Exits the process with ``sys.exit(-1)`` if the login request cannot be
        posted or the API response contains ``"errors"``; raises
        ``AssertionError`` on a non-OK HTTP status.
        """
        logger.info('obtain auth token at %s', self.api_url)
        data = {
            'username': username,
            'password': password
        }
        try:
            r = requests.post(self.api_url + '/publishers/1/login', headers=self.headers, data=json.dumps(data), verify=self.verify)
        except Exception as e:
            ## Python 3 exceptions have no ``.message``; format the exception itself
            logger.error("Cannot post request to GM API for user login: %s", e)
            sys.exit(-1)
        assert r.ok, r.reason
        result = r.json()
        if "errors" in result["response"]:
            logger.error("Error in logging in: %s", result["response"]["errors"])
            sys.exit(-1)
        user_auth_token = result['response']['user']['authentication_token']
        ## do not write the credential itself to the log
        logger.info("User auth token obtained")
        # set auth token in header
        self.headers['Authentication-Token'] = user_auth_token

    def init_session(self, data, direct=False, use_ip=False):
        """Request a recognition session and connect a websocket to it.

        :param data: JSON body for ``POST <api_url>/sessions``.
        :param direct: connect straight to the snode (``ws://<ip>:<port>/``)
            instead of via the data-centre proxy (``wss://<proxy>/<uuid>``).
        :param use_ip: when going via the proxy, use its ``ip`` rather than
            its ``hostname``.

        On success sets ``session_id``, ``proxy_url``, ``conn`` and ``session``
        and increments ``sessions[uuid]``.  If the session request or a status
        poll returns a non-OK HTTP status, the error is logged and the method
        returns without setting those attributes.

        :raises RuntimeError: if the backend reports an error allocating a snode.
        :raises Exception: re-raised after logging if connecting the websocket
            (or reading the session fields it needs) fails.
        """
        logger.info('Request new session: %s', data)
        r = requests.post(self.api_url + '/sessions', headers=self.headers, data=json.dumps(data), verify=self.verify)
        if not r.ok:
            logger.error("New session request failed: %s", r.text)
            return
        status_url = r.headers.get("location")
        if status_url:
            ## we got a redirect: poll the status URL once a second while it is PENDING
            while True:
                logger.debug("Checking %s", status_url)
                s = requests.get(status_url, verify=self.verify)
                if not s.ok:
                    logger.error('Checking Failed: %s', s.text)
                    return
                status = s.json()
                if status['status'] != 'PENDING':
                    break
                logger.debug("Status: %s", status['status'])
                time.sleep(1)
            session = status['result'][0] ## [1] is another status code...
            if "error" in session:
                logger.error("Error in getting a snode: %s", session["error"])
                raise RuntimeError("Error in getting a snode: %s" % (session["error"],))
        else:
            session = r.json()
        try:
            logger.info("Session: %r", session)
            if direct:
                snode_ip = session["snode"]["ip"]
                proxy_url = snode_ip
                snode_port = session["port"]
                ws_url = "%s://%s:%d/" % ("ws", snode_ip, snode_port)
            else:
                field = "ip" if use_ip else "hostname"
                proxy_url = session['snode']['datacentre']['proxy'][field]
                ws_url = 'wss://' + proxy_url + '/' + session['uuid']
            logger.info("Connecting to websocket: %s", ws_url)
            conn = websocket.create_connection(ws_url, sslopt={"check_hostname": self.verify})
            logger.info("Connected.")
        except Exception as e:
            ## Python 3 exceptions have no ``.message``; format the exception itself
            logger.error("Unable to connect to websocket: %s", e)
            raise
        self.session_id = session['id']
        self.proxy_url = proxy_url
        self.conn = conn
        self.session = session
        sessions[session["uuid"]] += 1

    def setgrammar(self, grammar): ## backend grammar object: {"data": {...}, "type": "confusion_network"}
        """Validate *grammar* and send it with a ``set_grammar`` JSON-RPC call.

        Returns the decoded backend reply; a reply with an ``"error"`` entry is
        logged but still returned.
        """
        data = {"jsonrpc": "2.0",
                'type': 'jsonrpc',
                'method': 'set_grammar',
                'params': grammar,
                "id": rpcid.next()}
        asr.spraaklab.schema.validate_rpc_grammar(grammar)
        self.conn.send(json.dumps(data))
        result = json.loads(self.conn.recv())
        if result.get("error"):
            logger.error("Exercise validation error: %s", result)
        return result

    def set_alternatives_grammar(self, *args, **kwargs):
        """Build a grammar with ``alternatives_grammar`` and send it.

        ``version`` defaults to ``self.grammar_version`` unless given explicitly.
        """
        if "version" not in kwargs:
            kwargs["version"] = self.grammar_version
        return self.setgrammar(alternatives_grammar(*args, **kwargs))

    def recognize_wav(self, wavf):
        """Recognize the audio in a .wav file (path or file object).

        The file must be mono, 2-byte PCM, 16000 Hz and uncompressed; otherwise
        an error is logged and ``None`` is returned.
        """
        w = wave.open(wavf, 'r')
        try:
            nchannels, sampwidth, framerate, nframes, comptype, compname = w.getparams()
            if nchannels > 1:
                logger.error("Please use .wav with only 1 channel, found %d channels in %s", nchannels, wavf)
                return
            if sampwidth != 2:
                logger.error("Please use .wav with 2-byte PCM data, found %d bytes in %s", sampwidth, wavf)
                return
            if framerate != 16000.0:
                logger.error("Please use .wav sampled at 16000 Hz, found %1.0f in %s", framerate, wavf)
                return
            if comptype != 'NONE':
                logger.error("Please use .wav with uncompressed data, found %s in %s", compname, wavf)
                return
            buf = w.readframes(nframes)
        finally:
            ## close the file on every path, including the early format-check returns
            w.close()
        return self.recognize_data(buf)

    def recognize_data(self, buf):
        """Stream raw audio bytes over the websocket and return the Segmentation.

        The audio is sent as binary frames of ``buffer_size`` bytes, followed by
        a ``get_result`` JSON-RPC request.
        """
        nbytes_sent = 0
        start = time.time()
        for j in range(0, len(buf), buffer_size):
            audio_packet = buf[j:j + buffer_size]
            nbytes_sent += len(audio_packet)
            self.conn.send_binary(audio_packet)
        logger.debug("Sent %d bytes of audio", nbytes_sent)
        self.conn.send(json.dumps({"jsonrpc": "2.0", "method": "get_result", "id": rpcid.next()}))
        return self._await_result(start)

    def recognize_url(self, url):
        """Ask the backend to fetch and recognize the audio at *url*."""
        start = time.time()
        data = json.dumps({"jsonrpc": "2.0", "method": "send_audio", "id": rpcid.next(), "params": {"type": "url", "data": url, "details": ["word", "utterance"]}})
        self.conn.send(data)
        return self._await_result(start)

    def _await_result(self, start):
        """Receive one recognition reply and convert it to a Segmentation.

        The raw reply is kept in ``self.last_message``.  *start* is the
        ``time.time()`` at which the request began, used for timing logs.

        :raises RuntimeError: if the reply contains an ``"error"`` entry.
        """
        logger.info("Waiting for recognition result...")
        self.last_message = self.conn.recv() ## keep result for the interested applications
        logger.debug("Backend reply: %s", self.last_message)
        message = json.loads(self.last_message)
        dur = time.time() - start
        logger.info("Recognition took %5.3f seconds", dur)
        if "error" in message:
            raise RuntimeError("Error from recognition backend: %r" % message.get("error"))
        return segmentation(message["result"]["words"])

    def __del__(self):
        """Release this object's share of the backend session.

        When the last Recognizer for a session uuid goes away, the websocket is
        closed and (unless ``keepopen``) the session is deleted via the API.
        Does nothing if the session was never established.
        """
        session = getattr(self, "session", None)
        if session is None:
            ## login or init_session failed before a session was attached
            return
        uuid = session["uuid"]
        sessions[uuid] -= 1
        if self.conn and sessions[uuid] <= 0:
            self.conn.close()
            url = self.api_url + '/sessions/%d' % self.session_id
            if self.keepopen:
                logger.info("Keeping session open...")
            else:
                logger.info("Closing session: %s", url)
                r = requests.delete(url, headers=self.headers, verify=self.verify)
                assert r.ok, r.reason
def alternatives_grammar(parts, version="0.1", ret=None):
    """Make a grammar of alternatives, as array(sequence)-of-array(alternatives)-of-strings.

    :param parts: sequence of slots, each a list of alternative strings.
    :param version: grammar format, ``"0.1"`` or ``"1.0"``.
    :param ret: optional list of return options.  For ``"0.1"`` only the
        presence of ``"dict"`` matters (sets ``return_dict``).  For ``"1.0"``
        the whole list is passed as ``return_objects``.
    :returns: the grammar dict, after schema validation.
    :raises ValueError: for any other *version*.

    In version ``"1.0"`` an alternative containing spaces becomes a nested
    ``sequence`` of its space-separated words.
    """
    grammar = {"type": "confusion_network", "version": version}
    ret_is_list = isinstance(ret, list)
    if version == "0.1":
        grammar["data"] = {"type": "multiple_choice", "parts": parts}
        if ret_is_list and "dict" in ret:
            grammar["return_dict"] = True
    elif version == "1.0":
        def _alternative(text):
            ## multi-word alternatives become a nested word sequence
            words = text.split(" ")
            if len(words) <= 1:
                return text
            return {"kind": "sequence", "elements": words}

        slots = [{"kind": "alternatives", "elements": [_alternative(choice) for choice in part]}
                 for part in parts]
        grammar["data"] = {"kind": "sequence", "elements": slots}
        if ret_is_list:
            grammar["return_objects"] = ret
    else:
        raise ValueError("Unsupported version: %s" % version)
    asr.spraaklab.schema.validate_rpc_grammar(grammar)
    return grammar