import os
|
|
from copy import deepcopy
|
|
from typing import Dict, List, Optional
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from schemas.stream import StreamDataModel
|
|
|
|
|
|
class LabelContext(BaseModel):
    """Ground-truth label for one audio segment: a time span plus its
    reference transcript."""

    # Segment start boundary (presumably seconds — confirm against caller).
    start: float
    # Segment end boundary.
    end: float
    # Reference (expected) transcription text for the segment.
    answer: str
|
|
|
|
|
|
class PredContext(BaseModel):
    """One recognition result together with its transport timestamps."""

    # Streaming recognition payload for this prediction.
    recognition_results: StreamDataModel
    # Wall-clock time the result was received; filled in after construction.
    recv_time: Optional[float] = None
    # Wall-clock time the corresponding audio was sent; filled in after
    # construction.
    send_time: Optional[float] = None
|
|
|
|
|
|
class ASRContext:
    """Mutable context for a single ASR (speech-recognition) session.

    Holds the audio/stream configuration, the ground-truth labels, the
    model predictions with their send/receive timestamps, and bookkeeping
    counters used by downstream evaluation.
    """

    def __init__(self, **kwargs):
        """Build the context from keyword options with env-var overrides.

        Recognized kwargs: ``bits``, ``channel``, ``sample_rate``,
        ``format``, ``enable_words``, ``char_contains_rate``, ``lang``,
        ``stream``, ``labels``, ``preds``.
        Environment variables: ``lang``, ``wait_time``, ``chunk_size_set``.
        """
        self.bits = kwargs.get("bits", 16)
        self.channel = kwargs.get("channel", 1)
        self.sample_rate = kwargs.get("sample_rate", 16000)
        self.audio_format = kwargs.get("format", "wav")
        self.enable_words = kwargs.get("enable_words", True)
        self.char_contains_rate = kwargs.get("char_contains_rate", 0.8)

        # The "lang" env var wins over the kwarg. Note an *empty* env value
        # is kept as-is: only a completely unset variable falls through.
        self.lang = os.getenv("lang")
        if self.lang is None:
            self.lang = kwargs.get("lang", "en")

        self.stream = kwargs.get("stream", True)

        # Interval between sends in seconds; env var overrides the default.
        self.wait_time = float(os.getenv("wait_time", 0.1))
        # Chunk size in bytes = sample_rate * bytes_per_sample * wait_time.
        # BUGFIX: true division made this a float (e.g. 3200.0) while the
        # env override below produces an int — a float chunk size cannot be
        # used for byte slicing/reads, so cast to int here as well.
        self.chunk_size = int(self.sample_rate * self.bits / 8 * self.wait_time)
        if int(os.getenv('chunk_size_set', 0)):
            self.chunk_size = int(os.getenv('chunk_size_set', 0))

        self.audio_length = 0
        self.file_path = ""

        # Ground-truth segments and model predictions.
        self.labels: List[LabelContext] = kwargs.get("labels", [])
        self.preds: List[PredContext] = kwargs.get("preds", [])

        self.label_sentences: List[str] = []
        self.pred_sentences: List[str] = []

        # [first, last] timestamps of the send/recv streams (see append_preds).
        self.send_time_start_end = []
        self.recv_time_start_end = []

        self.fail = False
        self.fail_char_contains_rate_num = 0

        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    def append_labels(self, voices: List[Dict]):
        """Validate each raw label dict and append it as a LabelContext."""
        for voice_data in voices:
            self.labels.append(LabelContext(**voice_data))

    def append_preds(
        self,
        predict_data: List[StreamDataModel],
        send_time: List[float],
        recv_time: List[float],
    ):
        """Record predictions paired with their send/receive timestamps.

        Also captures the first/last send and receive timestamps. Extra
        entries in any of the three lists are silently dropped by ``zip``.
        """
        self.send_time_start_end = [send_time[0], send_time[-1]] if send_time else []
        self.recv_time_start_end = [recv_time[0], recv_time[-1]] if recv_time else []
        for pred_item, sent_at, received_at in zip(predict_data, send_time, recv_time):
            # model_dump() already returns a fresh dict, so the deepcopy of
            # pred_item that used to precede it was redundant and is removed.
            pred_context = PredContext(recognition_results=pred_item.model_dump())
            pred_context.send_time = sent_at
            pred_context.recv_time = received_at
            self.preds.append(pred_context)

    def to_dict(self):
        """Serialize the session config; labels/preds become JSON strings."""
        return {
            "bits": self.bits,
            "channel": self.channel,
            "sample_rate": self.sample_rate,
            "audio_format": self.audio_format,
            "enable_words": self.enable_words,
            "stream": self.stream,
            "wait_time": self.wait_time,
            "chunk_size": self.chunk_size,
            "labels": [item.model_dump_json() for item in self.labels],
            "preds": [item.model_dump_json() for item in self.preds],
        }
|