#!/usr/bin/env python
# (c) 2015--2018 NovoLanguage, author: David A. van Leeuwen

## Recognition interface for actual backend.  Adapted from player.asr.debug.

import json
import sys
import wave
import requests
import websocket
import logging
import collections

import time

from .. import asr

logger = logging.getLogger(__name__)

## turn off annoying warnings
## (disable_warnings() with no argument silences all urllib3 warnings, e.g. the
## ones about unverified HTTPS requests -- Recognizer sets self.verify = False --
## and the connection-pool logger is raised to WARN so its info chatter is hidden)
requests.packages.urllib3.disable_warnings()
logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARN)

## chunk size (in bytes) used by Recognizer.recognize_data to stream audio
## over the websocket with send_binary()
buffer_size = 4096
## default GM API host; Recognizer builds its API url as
## "<protocol>://<gm>:<port>/v<apiversion>"
gm = "gm.novolanguage.com" ## dev
protocol = "https"
port = 443
apiversion = 0

## reference count of Recognizer objects per backend session uuid:
## incremented in Recognizer.init_session, decremented in Recognizer.__del__,
## which closes the websocket (and deletes the session unless keepopen) at <= 0
sessions = collections.Counter()

def segmentation(result):
    """Convert a raw backend word list into a novo.asr.segments Segmentation.

    Every word dict in ``result`` is modified *in place* before wrapping:

    * ``score`` / ``llh`` are copied from ``confidence["prob"]`` / ``confidence["llr"]``;
    * ``label`` is replaced by ``label["raw"]``;
    * ``begin`` / ``end`` are divided by 10 (time-unit conversion of the
      backend's stamps -- TODO confirm the units);

    and each entry of the word's ``phones`` list gets the same score/llh copy
    and begin/end division.  The mutated list is passed to
    ``asr.segments.Segmentation`` and that object is returned.
    """
    def copy_confidence(item):
        # lift the backend's nested confidence values to top-level keys
        confidence = item["confidence"]
        item["score"] = confidence["prob"]
        item["llh"] = confidence["llr"]

    def rescale_times(item):
        item["begin"] /= 10
        item["end"] /= 10

    for word in result:
        copy_confidence(word)
        word["label"] = word["label"]["raw"]
        rescale_times(word)
        for phone in word["phones"]:
            copy_confidence(phone)
            rescale_times(phone)
    return asr.segments.Segmentation(result)

class rpcid:
    """Process-wide source of JSON-RPC request ids (1, 2, 3, ...)."""
    id = 0  # most recently issued id; 0 before the first call

    @staticmethod
    def next():
        """Advance the shared counter and return the newly issued id."""
        issued = rpcid.id + 1
        rpcid.id = issued
        return issued

class Recognizer(object):
    """Client for one recognition session on the NovoLanguage backend.

    Construction logs in to the GM API, requests a new session and opens a
    websocket to it.  Grammars are sent with :meth:`setgrammar` or
    :meth:`set_alternatives_grammar`; audio is recognized with
    :meth:`recognize_wav`, :meth:`recognize_data` or :meth:`recognize_url`.
    The raw text of the most recent recognition reply is kept in
    ``last_message``.
    """

    def __init__(self, lang="en", gm=gm, grammar_version="0.1", user=None, password=None, snodeid=None, keepopen=False):
        """Log in, request a backend session and connect to it.

        :param lang: language code, sent as ``l2`` in the session request
        :param gm: host name of the GM API server
        :param grammar_version: default ``version`` used by set_alternatives_grammar
        :param user: GM API user name
        :param password: GM API password
        :param snodeid: optional ``snodeid`` forwarded in the session request
        :param keepopen: if true, the backend session is not deleted when this
            object is destroyed
        """
        self.lang = lang
        self.keepopen = keepopen
        self.api_url = "%s://%s:%d/v%d" % (protocol, gm, port, apiversion)
        self.verify = False  ## TLS certificate verification is off for GM API requests
        self.headers = {"Content-Type": "application/json"}
        self.login_user(user, password)
        data = {"l2": lang, "local": False, "skipupload": True}
        if snodeid:
            data["snodeid"] = snodeid
        self.conn = None
        self.init_session(data)
        self.grammar_version = grammar_version
        self.last_message = None

    def login_user(self, username, password):
        """Obtain an authentication token and add it to ``self.headers``.

        Exits the process with status -1 when the login request cannot be
        posted or when the response contains ``errors``.
        """
        logger.info('obtain auth token at %s', self.api_url)
        data = {
            'username': username,
            'password': password
        }
        try:
            r = requests.post(self.api_url + '/publishers/1/login', headers=self.headers, data=json.dumps(data), verify=self.verify)
        except Exception as e:
            ## format the exception itself: Python 3 exceptions have no .message,
            ## and accessing it here would raise AttributeError instead of exiting
            logger.error("Cannot post request to GM API for user login: %s", e)
            sys.exit(-1)
        assert r.ok, r.reason
        result = r.json()
        if "errors" in result["response"]:
            logger.error("Error in logging in: %s", result["response"]["errors"])
            sys.exit(-1)

        user_auth_token = result['response']['user']['authentication_token']
        ## the token is a credential: keep it out of normal (info) logs
        logger.debug("User auth token is: %s", user_auth_token)

        # set auth token in header
        self.headers['Authentication-Token'] = user_auth_token

    def init_session(self, data, direct=False, use_ip=False):
        """Request a new backend session and open a websocket to it.

        :param data: JSON body for the ``/sessions`` request
        :param direct: connect with ``ws://`` straight to the snode's ip and
            port instead of with ``wss://`` through the datacentre proxy
        :param use_ip: when going through the proxy, use its ``ip`` rather
            than its ``hostname``
        :raises RuntimeError: if the session status reports an error
        :raises Exception: whatever failed while connecting the websocket
            (logged, then re-raised)

        On success sets ``session``, ``session_id``, ``proxy_url`` and ``conn``
        and increments the session's count in ``sessions``.  If the session
        request or a status check fails, the error is logged and the method
        returns without connecting (``conn`` is left unchanged).
        """
        logger.info('Request new session: %s', data)
        r = requests.post(self.api_url + '/sessions', headers=self.headers, data=json.dumps(data), verify=self.verify)
        if not r.ok:
            logger.error("New session request failed: %s", r.text)
            return

        status_url = r.headers.get("location")
        if status_url:
            ## we got a redirect: poll the status url once a second while PENDING
            ## NOTE(review): there is no upper bound on the polling time
            while True:
                logger.debug("Checking %s", status_url)
                s = requests.get(status_url, verify=self.verify)
                if not s.ok:
                    logger.error('Checking Failed: %s', s.text)
                    return

                status = s.json()
                if status['status'] == 'PENDING':
                    logger.debug("Status: %s", status['status'])
                    time.sleep(1)
                else:
                    break
            session = status['result'][0] ## [1] is another status code...
            if "error" in session:
                logger.error("Error in getting a snode: %s", session["error"])
                raise RuntimeError("Error in getting a snode: %s" % (session["error"],))
        else:
            session = r.json()

        try:
            logger.info("Session: %r", session)
            if direct:
                snode_ip = session["snode"]["ip"]
                proxy_url = snode_ip
                snode_port = session["port"]
                ws_url = "%s://%s:%d/" % ("ws", snode_ip, snode_port)
            else:
                field = "ip" if use_ip else "hostname"
                proxy_url = session['snode']['datacentre']['proxy'][field]
                ws_url = 'wss://' + proxy_url + '/' + session['uuid']
            logger.info("Connecting to websocket: %s", ws_url)
            conn = websocket.create_connection(ws_url, sslopt={"check_hostname": self.verify})
            logger.info("Connected.")
        except Exception as e:
            ## log the exception itself (no .message on Python 3) and re-raise
            ## with the original traceback
            logger.error("Unable to connect to websocket: %s", e)
            raise

        self.session_id = session['id']
        self.proxy_url = proxy_url
        self.conn = conn
        self.session = session
        sessions[session["uuid"]] += 1

    def setgrammar(self, grammar): ## backend grammar object: {"data": {...}, "type": "confusion_network"}
        """Validate ``grammar``, send it as a ``set_grammar`` JSON-RPC call and
        return the decoded reply (an error in the reply is logged, not raised)."""
        data = {"jsonrpc": "2.0",
                'type': 'jsonrpc',
                'method': 'set_grammar',
                'params': grammar,
                "id": rpcid.next()}
        asr.spraaklab.schema.validate_rpc_grammar(grammar)
        self.conn.send(json.dumps(data))
        result = json.loads(self.conn.recv())
        if result.get("error"):
            logger.error("Exercise validation error: %s", result)
        return result

    def set_alternatives_grammar(self, *args, **kwargs):
        """Build a grammar with alternatives_grammar() and send it with setgrammar().

        The ``version`` keyword defaults to ``self.grammar_version``.
        """
        kwargs.setdefault("version", self.grammar_version)
        return self.setgrammar(alternatives_grammar(*args, **kwargs))

    def recognize_wav(self, wavf):
        """Recognize the audio in a .wav file (path or file object).

        The file must be mono, 2-byte PCM, sampled at 16000 Hz and
        uncompressed; otherwise an error is logged and None is returned.
        The wave reader is always closed before returning.
        """
        w = wave.open(wavf, 'r')
        try:
            nchannels, sampwidth, framerate, nframes, comptype, compname = w.getparams()
            if nchannels > 1:
                logger.error("Please use .wav with only 1 channel, found %d channels in %s", nchannels, wavf)
                return
            if sampwidth != 2:
                logger.error("Please use .wav with 2-byte PCM data, found %d bytes in %s", sampwidth, wavf)
                return
            if framerate != 16000.0:
                logger.error("Please use .wav sampled at 16000 Hz, found %1.0f in %s", framerate, wavf)
                return
            if comptype != 'NONE':
                logger.error("Please use .wav with uncompressed data, found %s in %s", compname, wavf)
                return
            buf = w.readframes(nframes)
        finally:
            w.close()
        return self.recognize_data(buf)

    def recognize_data(self, buf):
        """Stream raw audio bytes to the backend and return the recognition.

        ``buf`` is sent in ``buffer_size`` chunks as binary websocket frames,
        followed by a ``get_result`` JSON-RPC call.

        :raises RuntimeError: if the backend reply contains an error
        """
        start = time.time()
        nbytes_sent = 0
        for j in range(0, len(buf), buffer_size):
            audio_packet = buf[j:j + buffer_size]
            nbytes_sent += len(audio_packet)
            self.conn.send_binary(audio_packet)
        logger.debug("Sent %d bytes of audio", nbytes_sent)
        self.conn.send(json.dumps({"jsonrpc": "2.0", "method": "get_result", "id": rpcid.next()}))
        return self._receive_result(start)

    def recognize_url(self, url):
        """Ask the backend to fetch and recognize the audio at ``url``.

        :raises RuntimeError: if the backend reply contains an error
        """
        start = time.time()
        data = json.dumps({"jsonrpc": "2.0", "method": "send_audio", "id": rpcid.next(), "params": {"type": "url", "data": url, "details": ["word", "utterance"]}})
        self.conn.send(data)
        return self._receive_result(start)

    def _receive_result(self, start):
        """Wait for the recognition reply and convert it to a Segmentation.

        Keeps the raw reply text in ``self.last_message``.  ``start`` is the
        ``time.time()`` at which the request began, used to log its duration.

        :raises RuntimeError: if the reply contains an ``error`` member
        """
        logger.info("Waiting for recognition result...")
        self.last_message = self.conn.recv() ## keep result for the interested applications
        logger.debug("Recognition reply: %s", self.last_message)
        message = json.loads(self.last_message)
        dur = time.time() - start
        logger.info("Recognition took %5.3f seconds", dur)
        if "error" in message:
            raise RuntimeError("Error from recognition backend: %r" % message.get("error"))
        return segmentation(message["result"]["words"])

    def __del__(self):
        """Drop this object's reference to its backend session.

        When the last Recognizer for the session goes away, the websocket is
        closed and, unless ``keepopen`` was given, the session is deleted on
        the GM API.  Safe on a partially constructed object (no session yet).
        """
        session = getattr(self, "session", None)
        if session is None:
            ## __init__ never got as far as establishing a session
            return
        uuid = session["uuid"]
        sessions[uuid] -= 1
        if self.conn and sessions[uuid] <= 0:
            self.conn.close()
            url = self.api_url + '/sessions/%d' % self.session_id
            if self.keepopen:
                logger.info("Keeping session open...")
            else:
                logger.info("Closing session: %s", url)
                r = requests.delete(url, headers=self.headers, verify=self.verify)
                ## exceptions raised in __del__ are swallowed by the interpreter,
                ## so report a failed delete explicitly
                if not r.ok:
                    logger.error("Closing session failed: %s", r.reason)

def alternatives_grammar(parts, version="0.1", ret=None):
    """Make a grammar of alternatives, as array(sequence)-of-array(alternatives)-of-strings

    :param parts: sequence of slots; each slot is a sequence of alternative
        strings.  In version "1.0" an alternative containing a space is turned
        into a ``sequence`` element of its space-separated words.
    :param version: grammar format, "0.1" or "1.0"
    :param ret: optional list.  For "0.1", containing "dict" sets
        ``return_dict``; for "1.0" the list is stored as ``return_objects``.
    :returns: the ``confusion_network`` grammar dict, after schema validation
    :raises ValueError: for any other ``version``
    """
    grammar = {"type": "confusion_network", "version": version}
    ret_is_list = isinstance(ret, list)
    if version == "0.1":
        grammar["data"] = {"type": "multiple_choice", "parts": parts}
        if ret_is_list and "dict" in ret:
            grammar["return_dict"] = True
    elif version == "1.0":
        def as_element(alternative):
            # a multi-word alternative becomes an explicit word sequence
            tokens = alternative.split(" ")
            if len(tokens) <= 1:
                return alternative
            return {"kind": "sequence", "elements": tokens}

        slots = [
            {"kind": "alternatives", "elements": [as_element(alt) for alt in part]}
            for part in parts
        ]
        grammar["data"] = {"kind": "sequence", "elements": slots}
        if ret_is_list:
            grammar["return_objects"] = ret
    else:
        raise ValueError("Unsupported version: %s" % version)
    asr.spraaklab.schema.validate_rpc_grammar(grammar)
    return grammar