45 lines
1.4 KiB
Python
45 lines
1.4 KiB
Python
|
|
import subprocess
|
||
|
|
import sys
|
||
|
|
|
||
|
|
# Install kani-tts before `from kani_tts import KaniTTS` below.
# A plain `pip install` (no --upgrade) leaves an already-installed package
# untouched, so skip spawning the pip subprocess entirely when the module is
# already importable; only a missing package triggers the install.
import importlib.util

if importlib.util.find_spec("kani_tts") is None:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "kani-tts"])
|
||
|
|
|
||
|
|
import io
|
||
|
|
import base64
|
||
|
|
from typing import Any, Dict, List, Union
|
||
|
|
import numpy as np
|
||
|
|
import soundfile as sf
|
||
|
|
from kani_tts import KaniTTS
|
||
|
|
|
||
|
|
|
||
|
|
class EndpointHandler:
    """Inference handler that renders text to speech with KaniTTS.

    Loads the ``jsbeaudry/haitian-kani-ht-v3`` checkpoint (judging by its
    name, a Haitian Creole voice) once, then turns each request's text into a
    WAV file returned either base64-encoded in a JSON-friendly payload or as
    raw bytes.
    """

    # Speaker tag prepended to every prompt as "<speaker> : <text>".
    # NOTE(review): presumably a voice/speaker ID this checkpoint was trained
    # with — confirm before changing.
    DEFAULT_SPEAKER_ID = "3939afe3ea20"

    def __init__(self, path: str = ""):
        """Load the TTS model.

        Args:
            path: Model directory supplied by the serving runtime. Unused;
                the checkpoint is always loaded by its hub ID.
        """
        self.model = KaniTTS('jsbeaudry/haitian-kani-ht-v3')
        # Rate at which the model's output samples are interpreted.
        # NOTE(review): assumed to match the KaniTTS codec output — confirm.
        self.sample_rate = 22050

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, Any]], bytes]:
        """Synthesize speech for ``data["inputs"]``.

        Args:
            data: Request payload. Recognized keys:
                - ``"inputs"``: text to speak; must not be empty/blank.
                - ``"parameters"`` (optional dict, ``null`` treated as ``{}``):
                    - ``"output_format"``: ``"base64"`` (default) returns a
                      one-element list holding a dict; any other value
                      returns the raw WAV bytes.
                    - ``"sample_rate"``: output rate in Hz (default: the
                      model rate). Audio is resampled when it differs.
                    - ``"speaker_id"``: speaker tag placed before the text
                      (default: ``DEFAULT_SPEAKER_ID``).

        Returns:
            ``[{"audio", "sample_rate", "text", "encoding", "content_type"}]``
            when ``output_format == "base64"``, otherwise the WAV file bytes.

        Raises:
            ValueError: if ``inputs`` is empty/blank or ``sample_rate`` is not
                a positive integer.
        """
        inputs = data.get("inputs", "")
        # `or {}` also covers an explicit JSON null for "parameters".
        parameters = data.get("parameters") or {}
        output_format = parameters.get("output_format", "base64")
        speaker_id = parameters.get("speaker_id", self.DEFAULT_SPEAKER_ID)

        try:
            sample_rate = int(parameters.get("sample_rate", self.sample_rate))
        except (TypeError, ValueError) as err:
            raise ValueError("parameters.sample_rate must be an integer") from err
        if sample_rate <= 0:
            raise ValueError("parameters.sample_rate must be positive")

        if not inputs or (isinstance(inputs, str) and not inputs.strip()):
            raise ValueError("inputs must be a non-empty string")

        audio, text = self.model(f"{speaker_id} : {inputs}")
        # No-op for ndarrays; converts lists and other array-likes.
        audio = np.asarray(audio)

        if sample_rate != self.sample_rate:
            # Writing the model's samples under a different header rate would
            # only change playback speed and pitch, so resample instead.
            audio = self._resample(audio, self.sample_rate, sample_rate)

        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio, samplerate=sample_rate, format="WAV")
        audio_bytes = audio_buffer.getvalue()

        if output_format != "base64":
            return audio_bytes

        return [{
            "audio": base64.b64encode(audio_bytes).decode("utf-8"),
            "sample_rate": sample_rate,
            "text": text,
            "encoding": "base64",
            "content_type": "audio/wav",
        }]

    @staticmethod
    def _resample(audio: np.ndarray, src_rate: int, dst_rate: int) -> np.ndarray:
        """Linearly resample ``audio`` from ``src_rate`` to ``dst_rate`` Hz.

        The first axis is treated as time (frames), matching the
        ``(frames, channels)`` layout ``soundfile`` writes; any trailing axes
        are resampled independently. The input dtype is preserved (integer
        samples are rounded). Linear interpolation has no anti-aliasing
        filter, so downsampling can alias high frequencies, but duration and
        pitch are kept.
        """
        if audio.ndim == 0 or audio.size == 0:
            return audio
        n_src = audio.shape[0]
        n_dst = max(1, int(round(n_src * dst_rate / src_rate)))
        src_t = np.arange(n_src) / src_rate
        dst_t = np.arange(n_dst) / dst_rate
        # Flatten trailing axes to columns so each channel is interpolated alone.
        flat = audio.reshape(n_src, -1)
        out = np.column_stack([np.interp(dst_t, src_t, col) for col in flat.T])
        out = out.reshape((n_dst,) + audio.shape[1:])
        if np.issubdtype(audio.dtype, np.integer):
            out = np.rint(out)
        return out.astype(audio.dtype, copy=False)