update
This commit is contained in:
90
schemas/context.py
Normal file
90
schemas/context.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from schemas.stream import StreamDataModel
|
||||
|
||||
|
||||
class LabelContext(BaseModel):
    """One reference (ground-truth) entry: a span and its expected answer.

    Instances are built from plain dicts via ``LabelContext(**voice_data)``.
    """

    # Lower bound of the span (presumably a time offset into the audio;
    # the unit is not visible here -- TODO confirm against the label files).
    start: float
    # Upper bound of the span (same unit as ``start``).
    end: float
    # Expected text for the span.
    answer: str
|
||||
|
||||
|
||||
class PredContext(BaseModel):
    """One recognition result paired with its optional send/receive times."""

    # Result payload, validated as a StreamDataModel.
    recognition_results: StreamDataModel
    # Time associated with receiving this result; None until it is assigned.
    recv_time: Optional[float] = Field(default=None)
    # Time associated with sending the matching request; None until assigned.
    send_time: Optional[float] = Field(default=None)
|
||||
|
||||
|
||||
class ASRContext:
    """State holder for one ASR run: audio settings, labels, predictions, counters.

    Settings are taken from keyword arguments. The ``lang``, ``wait_time`` and
    ``chunk_size_set`` environment variables are also consulted, as described
    in ``__init__``.
    """

    def __init__(self, **kwargs):
        """Populate the context from ``kwargs`` and the environment.

        Keys read from ``kwargs``: ``bits``, ``channel``, ``sample_rate``,
        ``format``, ``enable_words``, ``char_contains_rate``, ``lang``,
        ``stream``, ``labels`` and ``preds``. Any other key is ignored.
        """
        option = kwargs.get

        # Audio / request settings and their defaults.
        self.bits = option("bits", 16)
        self.channel = option("channel", 1)
        self.sample_rate = option("sample_rate", 16000)
        # NOTE: the keyword is "format" but the attribute is "audio_format".
        self.audio_format = option("format", "wav")
        self.enable_words = option("enable_words", True)
        self.char_contains_rate = option("char_contains_rate", 0.8)

        # The "lang" environment variable wins over the keyword argument;
        # "en" is used only when neither is provided.
        env_lang = os.getenv("lang")
        self.lang = option("lang", "en") if env_lang is None else env_lang
        self.stream = option("stream", True)

        # Default chunk size is sample_rate * bits / 8 * wait_time. True
        # division makes this a float, and ``channel`` is not part of it.
        self.wait_time = float(os.getenv("wait_time", 0.1))
        self.chunk_size = self.sample_rate * self.bits / 8 * self.wait_time
        # A non-zero integer in "chunk_size_set" replaces the computed value.
        forced_chunk_size = int(os.getenv('chunk_size_set', 0))
        if forced_chunk_size:
            self.chunk_size = forced_chunk_size

        # Not filled in by this class; left at these initial values here.
        self.audio_length = 0
        self.file_path = ""

        # When supplied by the caller these lists are stored as-is (not
        # copied), so append_labels / append_preds extend the caller's list.
        self.labels: List[LabelContext] = option("labels", [])
        self.preds: List[PredContext] = option("preds", [])

        self.label_sentences: List[str] = []
        self.pred_sentences: List[str] = []

        # [first, last] pairs recorded by append_preds; empty until then.
        self.send_time_start_end = []
        self.recv_time_start_end = []

        # Outcome flags and counters, initialised to "nothing seen yet".
        self.fail = False
        self.fail_char_contains_rate_num = 0

        self.punctuation_num = 0
        self.pred_punctuation_num = 0

    def append_labels(self, voices: List[Dict]):
        """Convert each dict in ``voices`` to a LabelContext and append it.

        Each dict is expanded as keyword arguments, so it must carry the
        LabelContext fields (``start``, ``end``, ``answer``).
        """
        for entry in voices:
            self.labels.append(LabelContext(**entry))

    def append_preds(
        self,
        predict_data: List[StreamDataModel],
        send_time: List[float],
        recv_time: List[float],
    ):
        """Record predictions together with their send/receive times.

        ``send_time_start_end`` / ``recv_time_start_end`` are overwritten with
        the first and last element of the respective list (or ``[]`` when the
        list is empty). The three inputs are then walked in lockstep with
        ``zip``, so surplus items in a longer input are ignored.
        """
        if len(send_time) > 0:
            self.send_time_start_end = [send_time[0], send_time[-1]]
        else:
            self.send_time_start_end = []

        if len(recv_time) > 0:
            self.recv_time_start_end = [recv_time[0], recv_time[-1]]
        else:
            self.recv_time_start_end = []

        for result, sent_at, received_at in zip(predict_data, send_time, recv_time):
            # Dump a deep copy so the stored context never shares state with
            # the caller's object.
            payload = deepcopy(result).model_dump()
            context = PredContext(recognition_results=payload)
            # Times are assigned after construction, exactly as before.
            context.send_time = sent_at
            context.recv_time = received_at
            self.preds.append(context)

    def to_dict(self):
        """Return a dict summary of the context.

        Scalar settings are copied directly. ``labels`` and ``preds`` are
        lists of JSON *strings* (one ``model_dump_json()`` per item), not
        nested dicts. ``lang``, ``char_contains_rate`` and the counters are
        not included.
        """
        scalar_fields = (
            "bits",
            "channel",
            "sample_rate",
            "audio_format",
            "enable_words",
            "stream",
            "wait_time",
            "chunk_size",
        )
        summary = {name: getattr(self, name) for name in scalar_fields}
        summary["labels"] = [label.model_dump_json() for label in self.labels]
        summary["preds"] = [pred.model_dump_json() for pred in self.preds]
        return summary
|
||||
Reference in New Issue
Block a user