Skip to content

Commit

Permalink
Add wake HTTP server
Browse files Browse the repository at this point in the history
  • Loading branch information
synesthesiam committed Feb 26, 2024
1 parent f54e98e commit 1166e56
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 72 deletions.
2 changes: 1 addition & 1 deletion wyoming/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.5.3
1.5.4
42 changes: 8 additions & 34 deletions wyoming/http/asr_server.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,31 @@
"""HTTP server for automated speech recognition (ASR)."""

import argparse
import io
import logging
import wave
from pathlib import Path

from flask import Flask, Response, jsonify, redirect, request
from swagger_ui import flask_api_doc # pylint: disable=no-name-in-module
from flask import Response, jsonify, request

from wyoming.asr import Transcribe, Transcript
from wyoming.audio import wav_to_chunks
from wyoming.client import AsyncClient
from wyoming.error import Error
from wyoming.info import Describe, Info

from .shared import get_app, get_argument_parser

_DIR = Path(__file__).parent
CONF_PATH = _DIR / "conf" / "asr.yaml"


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--uri", help="URI of Wyoming ASR service")
parser = get_argument_parser()
parser.add_argument("--model", help="Default model name for transcription")
parser.add_argument("--language", help="Default language for transcription")
parser.add_argument("--samples-per-chunk", type=int, default=1024)
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

app = Flask("asr")

@app.route("/")
def redirect_to_api():
return redirect("/api")
app = get_app("asr", CONF_PATH, args)

@app.route("/api/speech-to-text", methods=["POST"])
async def api_stt() -> Response:
Expand All @@ -52,7 +45,7 @@ async def api_stt() -> Response:
with wave.open(wav_io, "rb") as wav_file:
chunks = wav_to_chunks(
wav_file,
samples_per_chunk=1024,
samples_per_chunk=args.samples_per_chunk,
start_event=True,
stop_event=True,
)
Expand All @@ -74,25 +67,6 @@ async def api_stt() -> Response:
f"Unexpected error from client: code={error.code}, text={error.text}"
)

@app.route("/api/info", methods=["GET"])
async def api_info():
uri = request.args.get("uri", args.uri)
if not uri:
raise ValueError("URI is required")

async with AsyncClient.from_uri(uri) as client:
await client.write_event(Describe().event())

while True:
event = await client.read_event()
if event is None:
raise RuntimeError("Client disconnected")

if Info.is_type(event.type):
info = Info.from_event(event)
return jsonify(info.to_dict())

flask_api_doc(app, config_path=str(CONF_PATH), url_prefix="/api", title="API doc")
app.run(args.host, args.port)


Expand Down
40 changes: 40 additions & 0 deletions wyoming/http/conf/wake.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
---
openapi: "3.0.0"
info:
title: 'Wyoming Wake'
version: '1.0.0'
description: 'API for Wake Word Detection'
paths:
/api/info:
get:
summary: 'Get service information'
responses:
'200':
description: OK
content:
application/json:
schema:
/api/detect-wake-word:
post:
summary: 'Transcribe WAV data to text'
requestBody:
description: 'WAV data (16-bit 16Khz mono preferred)'
required: true
content:
audio/wav:
schema:
type: string
format: binary
parameters:
- in: query
name: uri
description: 'URI of Wyoming ASR service'
schema:
type: string
responses:
'200':
description: OK
content:
application/json:
schema:
type: object
63 changes: 63 additions & 0 deletions wyoming/http/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Shared code for HTTP servers."""
import argparse
from pathlib import Path
from typing import Union

from flask import Flask, jsonify, redirect, request
from swagger_ui import flask_api_doc # pylint: disable=no-name-in-module

from wyoming.client import AsyncClient
from wyoming.info import Describe, Info


def get_argument_parser() -> argparse.ArgumentParser:
"""Create argument parser with shared arguments."""
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--uri", help="URI of Wyoming service")
parser.add_argument(
"--debug", action="store_true", help="Print DEBUG logs to console"
)
return parser


def get_app(
name: str, openapi_config_path: Union[str, Path], args: argparse.Namespace
) -> Flask:
"""Create Flask app with default endpoints."""

app = Flask(name)

@app.route("/")
def redirect_to_api():
return redirect("/api")

@app.route("/api/info", methods=["GET"])
async def api_info():
uri = request.args.get("uri", args.uri)
if not uri:
raise ValueError("URI is required")

async with AsyncClient.from_uri(uri) as client:
await client.write_event(Describe().event())

while True:
event = await client.read_event()
if event is None:
raise RuntimeError("Client disconnected")

if Info.is_type(event.type):
info = Info.from_event(event)
return jsonify(info.to_dict())

@app.errorhandler(Exception)
async def handle_error(err):
"""Return error as text."""
return (f"{err.__class__.__name__}: {err}", 500)

flask_api_doc(
app, config_path=str(openapi_config_path), url_prefix="/api", title="API doc"
)

return app
44 changes: 7 additions & 37 deletions wyoming/http/tts_server.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,31 @@
"""HTTP server for text to speech (TTS)."""
import argparse
import io
import logging
import wave
from pathlib import Path
from typing import Optional

from flask import Flask, Response, jsonify, redirect, request
from swagger_ui import flask_api_doc # pylint: disable=no-name-in-module
from flask import Response, request

from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncClient
from wyoming.error import Error
from wyoming.info import Describe, Info
from wyoming.tts import Synthesize, SynthesizeVoice

from .shared import get_app, get_argument_parser

_DIR = Path(__file__).parent
CONF_PATH = _DIR / "conf" / "tts.yaml"


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--uri", help="URI of Wyoming ASR service")
parser = get_argument_parser()
parser.add_argument("--voice", help="Default voice for synthesis")
parser.add_argument("--speaker", help="Default voice speaker for synthesis")
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

app = Flask("tts")

@app.route("/")
def redirect_to_api():
return redirect("/api")
app = get_app("tts", CONF_PATH, args)

@app.route("/api/text-to-speech", methods=["POST", "GET"])
async def api_stt() -> Response:
Expand Down Expand Up @@ -84,30 +78,6 @@ async def api_stt() -> Response:
f"Unexpected error from client: code={error.code}, text={error.text}"
)

@app.route("/api/info", methods=["GET"])
async def api_info():
uri = request.args.get("uri", args.uri)
if not uri:
raise ValueError("URI is required")

async with AsyncClient.from_uri(uri) as client:
await client.write_event(Describe().event())

while True:
event = await client.read_event()
if event is None:
raise RuntimeError("Client disconnected")

if Info.is_type(event.type):
info = Info.from_event(event)
return jsonify(info.to_dict())

@app.errorhandler(Exception)
async def handle_error(err):
"""Return error as text."""
return (f"{err.__class__.__name__}: {err}", 500)

flask_api_doc(app, config_path=str(CONF_PATH), url_prefix="/api", title="API doc")
app.run(args.host, args.port)


Expand Down
64 changes: 64 additions & 0 deletions wyoming/http/wake_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""HTTP server for wake word detection."""
import io
import logging
import wave
from pathlib import Path

from flask import Response, jsonify, request

from wyoming.audio import wav_to_chunks
from wyoming.client import AsyncClient
from wyoming.error import Error
from wyoming.wake import Detection, NotDetected

from .shared import get_app, get_argument_parser

_DIR = Path(__file__).parent
CONF_PATH = _DIR / "conf" / "wake.yaml"


def main():
parser = get_argument_parser()
parser.add_argument("--samples-per-chunk", type=int, default=1024)
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

app = get_app("wake", CONF_PATH, args)

@app.route("/api/detect-wake-word", methods=["POST", "GET"])
async def api_wake() -> Response:
uri = request.args.get("uri", args.uri)
if not uri:
raise ValueError("URI is required")

async with AsyncClient.from_uri(uri) as client:
with io.BytesIO(request.data) as wav_io:
with wave.open(wav_io, "rb") as wav_file:
chunks = wav_to_chunks(
wav_file,
samples_per_chunk=args.samples_per_chunk,
start_event=True,
stop_event=True,
)
for chunk in chunks:
await client.write_event(chunk.event())

while True:
event = await client.read_event()
if event is None:
raise RuntimeError("Client disconnected")

if Detection.is_type(event.type) or NotDetected.is_type(event.type):
return jsonify(event.to_dict())

if Error.is_type(event.type):
error = Error.from_event(event)
raise RuntimeError(
f"Unexpected error from client: code={error.code}, text={error.text}"
)

app.run(args.host, args.port)


if __name__ == "__main__":
main()

0 comments on commit 1166e56

Please sign in to comment.