Add minimal vLLM 0.16.1 build repo for BI-V150
2
vllm/entrypoints/openai/speech_to_text/__init__.py
Normal file
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
159
vllm/entrypoints/openai/speech_to_text/api_router.py
Normal file
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from http import HTTPStatus
from typing import TYPE_CHECKING, Annotated

from fastapi import APIRouter, FastAPI, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.speech_to_text.protocol import (
    TranscriptionRequest,
    TranscriptionResponseVariant,
    TranslationRequest,
    TranslationResponseVariant,
)
from vllm.entrypoints.openai.speech_to_text.serving import (
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
from vllm.entrypoints.utils import (
    load_aware_call,
    with_cancellation,
)
from vllm.logger import init_logger

if TYPE_CHECKING:
    from argparse import Namespace

    from starlette.datastructures import State

    from vllm.engine.protocol import EngineClient
    from vllm.entrypoints.logger import RequestLogger
    from vllm.tasks import SupportedTask
else:
    RequestLogger = object

logger = init_logger(__name__)

router = APIRouter()


def transcription(request: Request) -> OpenAIServingTranscription:
    return request.app.state.openai_serving_transcription


def translation(request: Request) -> OpenAIServingTranslation:
    return request.app.state.openai_serving_translation


@router.post(
    "/v1/audio/transcriptions",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_transcriptions(
    raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
):
    handler = transcription(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Transcriptions API"
        )

    audio_data = await request.file.read()
    try:
        generator = await handler.create_transcription(audio_data, request, raw_request)
    except Exception as e:
        return handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, TranscriptionResponseVariant):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


@router.post(
    "/v1/audio/translations",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_translations(
    request: Annotated[TranslationRequest, Form()], raw_request: Request
):
    handler = translation(raw_request)
    if handler is None:
        base_server = raw_request.app.state.openai_serving_tokenization
        return base_server.create_error_response(
            message="The model does not support Translations API"
        )

    audio_data = await request.file.read()
    try:
        generator = await handler.create_translation(audio_data, request, raw_request)
    except Exception as e:
        return handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, TranslationResponseVariant):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


def attach_router(app: FastAPI):
    app.include_router(router)


def init_transcription_state(
    engine_client: "EngineClient",
    state: "State",
    args: "Namespace",
    request_logger: RequestLogger | None,
    supported_tasks: tuple["SupportedTask", ...],
):
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
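For context, a minimal sketch of exercising the transcription route above with the OpenAI Python client. The server address, model name, and audio file are illustrative assumptions, not part of this commit.

# Minimal usage sketch (assumptions: a vLLM server on localhost:8000 serving
# an ASR model, and a local "audio.wav"; the model name is illustrative).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("audio.wav", "rb") as f:
    result = client.audio.transcriptions.create(
        model="openai/whisper-large-v3",
        file=f,
        language="en",
        response_format="json",
    )
print(result.text)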
545
vllm/entrypoints/openai/speech_to_text/protocol.py
Normal file
@@ -0,0 +1,545 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from http import HTTPStatus
from typing import Literal, TypeAlias

import torch
from fastapi import HTTPException, UploadFile
from pydantic import (
    Field,
    model_validator,
)

from vllm.entrypoints.openai.engine.protocol import (
    DeltaMessage,
    OpenAIBaseModel,
    UsageInfo,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.sampling_params import (
    RequestOutputKind,
    SamplingParams,
)
from vllm.utils import random_uuid

logger = init_logger(__name__)
_LONG_INFO = torch.iinfo(torch.long)


class TranscriptionResponseStreamChoice(OpenAIBaseModel):
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None


class TranscriptionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
    object: Literal["transcription.chunk"] = "transcription.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranscriptionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)


## Protocols for Audio
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"]


class TranscriptionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranscription

    file: UploadFile
    """
    The audio file object (not file name) to transcribe, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use."""

    language: str | None = None
    """The language of the input audio.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy and latency.
    """

    prompt: str = Field(default="")
    """Optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    ## TODO (varun) : Support if set to 0, certain thresholds are met !!

    timestamp_granularities: list[Literal["word", "segment"]] = Field(
        alias="timestamp_granularities[]", default=[]
    )
    """The timestamp granularities to populate for this transcription.

    `response_format` must be set to `verbose_json` to use timestamp
    granularities. Either or both of these options are supported: `word` or
    `segment`. Note: there is no additional latency for segment timestamps,
    but generating word timestamps incurs additional latency.
    """

    stream: bool | None = False
    """When set, output is streamed in a similar fashion to the Chat
    Completion endpoint.
    """
    # --8<-- [start:transcription-extra-params]
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    vllm_xargs: dict[str, str | int | float] | None = Field(
        default=None,
        description=(
            "Additional request parameters with string or "
            "numeric values, used by custom extensions."
        ),
    )
    # --8<-- [end:transcription-extra-params]

    to_language: str | None = None
    """The language to transcribe the output into.

    Note that this is not currently used by supported models; it is a
    placeholder for future use, matching the translation API.
    """

    # --8<-- [start:transcription-sampling-params]
    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower
    values like 0.2 will make it more focused and deterministic. If set to 0,
    the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """

    top_p: float | None = None
    """Enables nucleus (top-p) sampling, where tokens are selected from the
    smallest possible set whose cumulative probability exceeds `p`.
    """

    top_k: int | None = None
    """Limits sampling to the `k` most probable tokens at each step."""

    min_p: float | None = None
    """Filters out tokens with a probability lower than `min_p`, ensuring a
    minimum likelihood threshold during sampling.
    """

    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    frequency_penalty: float | None = 0.0
    """The frequency penalty to use for sampling."""

    repetition_penalty: float | None = None
    """The repetition penalty to use for sampling."""

    presence_penalty: float | None = 0.0
    """The presence penalty to use for sampling."""

    max_completion_tokens: int | None = None
    """The maximum number of tokens to generate."""
    # --8<-- [end:transcription-sampling-params]

    # Default sampling parameters for transcription requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "repetition_penalty": 1.0,
        "temperature": 1.0,
        "top_p": 1.0,
        "top_k": 0,
        "min_p": 0.0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}

        # Default parameters
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
            )
        if (top_k := self.top_k) is None:
            top_k = default_sampling_params.get(
                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
            )
        if (min_p := self.min_p) is None:
            min_p = default_sampling_params.get(
                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
            )

        if (repetition_penalty := self.repetition_penalty) is None:
            repetition_penalty = default_sampling_params.get(
                "repetition_penalty",
                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=repetition_penalty,
            presence_penalty=self.presence_penalty,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            extra_args=self.vllm_xargs,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def validate_transcription_request(cls, data):
        if isinstance(data.get("file"), str):
            raise HTTPException(
                status_code=HTTPStatus.UNPROCESSABLE_ENTITY,
                detail="Expected 'file' to be a file-like object, not 'str'.",
            )

        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            # Find which specific stream option was set
            invalid_param = next(
                (so for so in stream_opts if data.get(so, False)),
                "stream_include_usage",
            )
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter=invalid_param,
            )

        return data
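The walrus-style blocks in `to_sampling_params` implement a three-level precedence: an explicit request value wins, then the server-supplied `default_sampling_params` dict, then the class-level `_DEFAULT_SAMPLING_PARAMS` fallback. A minimal standalone sketch of that precedence (names and values illustrative):

# Standalone sketch of the default-resolution order used above.
_FALLBACK = {"temperature": 1.0, "top_p": 1.0}

def resolve(request_value, server_defaults, key):
    if request_value is not None:
        return request_value  # explicit request value wins
    return server_defaults.get(key, _FALLBACK[key])

assert resolve(0.2, {"temperature": 0.7}, "temperature") == 0.2  # request wins
assert resolve(None, {"top_p": 0.9}, "top_p") == 0.9             # server default next
assert resolve(None, {}, "temperature") == 1.0                   # class fallback last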
# Transcription response objects
class TranscriptionUsageAudio(OpenAIBaseModel):
    type: Literal["duration"] = "duration"
    seconds: int


class TranscriptionResponse(OpenAIBaseModel):
    text: str
    """The transcribed text."""
    usage: TranscriptionUsageAudio


class TranscriptionWord(OpenAIBaseModel):
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranscriptionSegment(OpenAIBaseModel):
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1,
    consider this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranscriptionResponseVerbose(OpenAIBaseModel):
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The transcribed text."""

    segments: list[TranscriptionSegment] | None = None
    """Segments of the transcribed text and their corresponding details."""

    words: list[TranscriptionWord] | None = None
    """Extracted words and their corresponding timestamps."""


TranscriptionResponseVariant: TypeAlias = (
    TranscriptionResponse | TranscriptionResponseVerbose
)


class TranslationResponseStreamChoice(OpenAIBaseModel):
    delta: DeltaMessage
    finish_reason: str | None = None
    stop_reason: int | str | None = None


class TranslationStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
    object: Literal["translation.chunk"] = "translation.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[TranslationResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)


class TranslationRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/audio/createTranslation

    file: UploadFile
    """
    The audio file object (not file name) to translate, in one of these
    formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    """

    model: str | None = None
    """ID of the model to use."""

    prompt: str = Field(default="")
    """Optional text to guide the model's style or continue a previous audio
    segment.

    The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
    should match the audio language.
    """

    response_format: AudioResponseFormat = Field(default="json")
    """
    The format of the output, in one of these options: `json`, `text`, `srt`,
    `verbose_json`, or `vtt`.
    """

    # TODO support additional sampling parameters
    # --8<-- [start:translation-sampling-params]
    seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
    """The seed to use for sampling."""

    temperature: float = Field(default=0.0)
    """The sampling temperature, between 0 and 1.

    Higher values like 0.8 will make the output more random, while lower
    values like 0.2 will make it more focused and deterministic. If set to 0,
    the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
    to automatically increase the temperature until certain thresholds are hit.
    """
    # --8<-- [end:translation-sampling-params]

    # --8<-- [start:translation-extra-params]
    language: str | None = None
    """The language of the input audio we translate from.

    Supplying the input language in
    [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
    will improve accuracy.
    """

    to_language: str | None = None
    """The language we translate the input audio into.

    Note that this is not supported by all models; refer to the specific
    model documentation for details. For instance, Whisper only supports
    `to_language=en`.
    """

    stream: bool | None = False
    """Custom field not present in the original OpenAI definition. When set,
    output is streamed in a similar fashion to the Chat Completion endpoint.
    """
    # Flattened stream option to simplify form data.
    stream_include_usage: bool | None = False
    stream_continuous_usage_stats: bool | None = False

    max_completion_tokens: int | None = None
    """The maximum number of tokens to generate."""
    # --8<-- [end:translation-extra-params]

    # Default sampling parameters for translation requests.
    _DEFAULT_SAMPLING_PARAMS: dict = {
        "temperature": 0,
    }

    def to_sampling_params(
        self, default_max_tokens: int, default_sampling_params: dict | None = None
    ) -> SamplingParams:
        max_tokens = default_max_tokens

        if default_sampling_params is None:
            default_sampling_params = {}
        # Default parameters
        if (temperature := self.temperature) is None:
            temperature = default_sampling_params.get(
                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]
            )

        return SamplingParams.from_optional(
            temperature=temperature,
            max_tokens=max_tokens,
            seed=self.seed,
            output_kind=RequestOutputKind.DELTA
            if self.stream
            else RequestOutputKind.FINAL_ONLY,
            skip_clone=True,  # Created fresh per request, safe to skip clone
        )

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
        stream = data.get("stream", False)
        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
            # Find which specific stream option was set
            invalid_param = next(
                (so for so in stream_opts if data.get(so, False)),
                "stream_include_usage",
            )
            raise VLLMValidationError(
                "Stream options can only be defined when `stream=True`.",
                parameter=invalid_param,
            )

        return data


# Translation response objects
class TranslationResponse(OpenAIBaseModel):
    text: str
    """The translated text."""


class TranslationWord(OpenAIBaseModel):
    end: float
    """End time of the word in seconds."""

    start: float
    """Start time of the word in seconds."""

    word: str
    """The text content of the word."""


class TranslationSegment(OpenAIBaseModel):
    id: int
    """Unique identifier of the segment."""

    avg_logprob: float
    """Average logprob of the segment.

    If the value is lower than -1, consider the logprobs failed.
    """

    compression_ratio: float
    """Compression ratio of the segment.

    If the value is greater than 2.4, consider the compression failed.
    """

    end: float
    """End time of the segment in seconds."""

    no_speech_prob: float | None = None
    """Probability of no speech in the segment.

    If the value is higher than 1.0 and the `avg_logprob` is below -1,
    consider this segment silent.
    """

    seek: int
    """Seek offset of the segment."""

    start: float
    """Start time of the segment in seconds."""

    temperature: float
    """Temperature parameter used for generating the segment."""

    text: str
    """Text content of the segment."""

    tokens: list[int]
    """Array of token IDs for the text content."""


class TranslationResponseVerbose(OpenAIBaseModel):
    duration: str
    """The duration of the input audio."""

    language: str
    """The language of the input audio."""

    text: str
    """The translated text."""

    segments: list[TranslationSegment] | None = None
    """Segments of the translated text and their corresponding details."""

    words: list[TranslationWord] | None = None
    """Extracted words and their corresponding timestamps."""


TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose
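To illustrate the wire format, a minimal sketch of what one streaming chunk defined above serializes to. This assumes the package added in this commit is importable; the model name is illustrative.

from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.entrypoints.openai.speech_to_text.protocol import (
    TranscriptionResponseStreamChoice,
    TranscriptionStreamResponse,
)

chunk = TranscriptionStreamResponse(
    model="openai/whisper-large-v3",  # illustrative
    choices=[TranscriptionResponseStreamChoice(delta=DeltaMessage(content="Hello"))],
)
# Emits an object like: {"id": "trsc-...", "object": "transcription.chunk",
# "created": ..., "model": "...", "choices": [{"delta": {"content": "Hello"}}]}
print(chunk.model_dump_json(exclude_unset=True))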
176
vllm/entrypoints/openai/speech_to_text/serving.py
Normal file
@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import AsyncGenerator

from fastapi import Request

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
    RequestResponseMetadata,
)
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text.protocol import (
    TranscriptionRequest,
    TranscriptionResponse,
    TranscriptionResponseStreamChoice,
    TranscriptionResponseVerbose,
    TranscriptionStreamResponse,
    TranslationRequest,
    TranslationResponse,
    TranslationResponseStreamChoice,
    TranslationResponseVerbose,
    TranslationStreamResponse,
)
from vllm.entrypoints.openai.speech_to_text.speech_to_text import OpenAISpeechToText
from vllm.logger import init_logger
from vllm.outputs import RequestOutput

logger = init_logger(__name__)


class OpenAIServingTranscription(OpenAISpeechToText):
    """Handles transcription requests."""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
        enable_force_include_usage: bool = False,
    ):
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            task_type="transcribe",
            log_error_stack=log_error_stack,
            enable_force_include_usage=enable_force_include_usage,
        )

    async def create_transcription(
        self,
        audio_data: bytes,
        request: TranscriptionRequest,
        raw_request: Request | None = None,
    ) -> (
        TranscriptionResponse
        | TranscriptionResponseVerbose
        | AsyncGenerator[str, None]
        | ErrorResponse
    ):
        """Transcription API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/audio/createTranscription
        for the API specification. This API mimics the OpenAI transcription API.
        """
        return await self._create_speech_to_text(
            audio_data=audio_data,
            request=request,
            raw_request=raw_request,
            response_class=(
                TranscriptionResponseVerbose
                if request.response_format == "verbose_json"
                else TranscriptionResponse
            ),
            stream_generator_method=self.transcription_stream_generator,
        )

    async def transcription_stream_generator(
        self,
        request: TranscriptionRequest,
        result_generator: list[AsyncGenerator[RequestOutput, None]],
        request_id: str,
        request_metadata: RequestResponseMetadata,
        audio_duration_s: float,
    ) -> AsyncGenerator[str, None]:
        generator = self._speech_to_text_stream_generator(
            request=request,
            list_result_generator=result_generator,
            request_id=request_id,
            request_metadata=request_metadata,
            audio_duration_s=audio_duration_s,
            chunk_object_type="transcription.chunk",
            response_stream_choice_class=TranscriptionResponseStreamChoice,
            stream_response_class=TranscriptionStreamResponse,
        )
        async for chunk in generator:
            yield chunk


class OpenAIServingTranslation(OpenAISpeechToText):
    """Handles translation requests."""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        log_error_stack: bool = False,
        enable_force_include_usage: bool = False,
    ):
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            task_type="translate",
            log_error_stack=log_error_stack,
            enable_force_include_usage=enable_force_include_usage,
        )

    async def create_translation(
        self,
        audio_data: bytes,
        request: TranslationRequest,
        raw_request: Request | None = None,
    ) -> (
        TranslationResponse
        | TranslationResponseVerbose
        | AsyncGenerator[str, None]
        | ErrorResponse
    ):
        """Translation API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/audio/createTranslation
        for the API specification. This API mimics the OpenAI translation API.
        """
        return await self._create_speech_to_text(
            audio_data=audio_data,
            request=request,
            raw_request=raw_request,
            response_class=(
                TranslationResponseVerbose
                if request.response_format == "verbose_json"
                else TranslationResponse
            ),
            stream_generator_method=self.translation_stream_generator,
        )

    async def translation_stream_generator(
        self,
        request: TranslationRequest,
        result_generator: list[AsyncGenerator[RequestOutput, None]],
        request_id: str,
        request_metadata: RequestResponseMetadata,
        audio_duration_s: float,
    ) -> AsyncGenerator[str, None]:
        generator = self._speech_to_text_stream_generator(
            request=request,
            list_result_generator=result_generator,
            request_id=request_id,
            request_metadata=request_metadata,
            audio_duration_s=audio_duration_s,
            chunk_object_type="translation.chunk",
            response_stream_choice_class=TranslationResponseStreamChoice,
            stream_response_class=TranslationStreamResponse,
        )
        async for chunk in generator:
            yield chunk
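The translation route is exercised the same way as transcription; a minimal sketch with the OpenAI client, under the same illustrative assumptions (local server, model name, audio file) as the earlier example:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("speech_de.wav", "rb") as f:  # illustrative non-English clip
    result = client.audio.translations.create(
        model="openai/whisper-large-v3",  # illustrative
        file=f,
    )
print(result.text)  # English translation of the audio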
770
vllm/entrypoints/openai/speech_to_text/speech_to_text.py
Normal file
@@ -0,0 +1,770 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
import zlib
from collections.abc import AsyncGenerator, Callable
from functools import cached_property
from typing import Final, Literal, TypeAlias, TypeVar, cast

import numpy as np
from fastapi import Request
from transformers import PreTrainedTokenizerBase

import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import (
    DeltaMessage,
    ErrorResponse,
    RequestResponseMetadata,
    UsageInfo,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing, SpeechToTextRequest
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.openai.speech_to_text.protocol import (
    TranscriptionResponse,
    TranscriptionResponseStreamChoice,
    TranscriptionResponseVerbose,
    TranscriptionSegment,
    TranscriptionStreamResponse,
    TranslationResponse,
    TranslationResponseStreamChoice,
    TranslationResponseVerbose,
    TranslationSegment,
    TranslationStreamResponse,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import (
    SupportsTranscription,
    supports_transcription,
)
from vllm.multimodal.audio import split_audio
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule

try:
    import librosa
except ImportError:
    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]

SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
    TranscriptionResponseVerbose | TranslationResponseVerbose
)
SpeechToTextSegment: TypeAlias = TranscriptionSegment | TranslationSegment
T = TypeVar("T", bound=SpeechToTextResponse)
V = TypeVar("V", bound=SpeechToTextResponseVerbose)
S = TypeVar("S", bound=SpeechToTextSegment)

ResponseType: TypeAlias = (
    TranscriptionResponse
    | TranslationResponse
    | TranscriptionResponseVerbose
    | TranslationResponseVerbose
)

logger = init_logger(__name__)


class OpenAISpeechToText(OpenAIServing):
    """Base class for speech-to-text operations like transcription and
    translation."""

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        return_tokens_as_token_ids: bool = False,
        task_type: Literal["transcribe", "translate"] = "transcribe",
        log_error_stack: bool = False,
        enable_force_include_usage: bool = False,
    ):
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            return_tokens_as_token_ids=return_tokens_as_token_ids,
            log_error_stack=log_error_stack,
        )

        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.task_type: Final = task_type

        self.asr_config = self.model_cls.get_speech_to_text_config(
            self.model_config, task_type
        )

        self.enable_force_include_usage = enable_force_include_usage

        self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB
        if self.model_cls.supports_segment_timestamp:
            self.tokenizer = cast(
                PreTrainedTokenizerBase,
                get_tokenizer(
                    tokenizer_name=self.model_config.tokenizer,
                    tokenizer_mode=self.model_config.tokenizer_mode,
                ),
            )

        if self.default_sampling_params:
            logger.info(
                "Overwriting default completion sampling param with: %s",
                self.default_sampling_params,
            )

        # Warm up audio preprocessing to avoid first-request latency
        self._warmup_audio_preprocessing()
        # Warm up input processor with dummy audio
        self._warmup_input_processor()

    def _warmup_audio_preprocessing(self) -> None:
        """Warm up audio processing libraries to avoid first-request latency.

        The first call to librosa functions (load, get_duration,
        mel-spectrogram) triggers JIT compilation and library initialization
        which can take ~7s. This method warms up these operations during
        server initialization.
        """
        # Skip warmup if librosa is not installed (optional dependency)
        if isinstance(librosa, PlaceholderModule):
            return

        # Skip warmup if model doesn't support transcription
        if not supports_transcription(self.model_cls):
            return

        if getattr(self.model_cls, "skip_warmup_audio_preprocessing", False):
            return

        try:
            warmup_start = time.perf_counter()
            logger.info("Warming up audio preprocessing libraries...")

            # Create a minimal dummy audio (1 second of silence at target sample rate)
            dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)

            # Warm up librosa.load by using librosa functions on the dummy data.
            # This initializes FFTW, numba JIT, and other audio processing libraries.
            _ = librosa.get_duration(y=dummy_audio, sr=self.asr_config.sample_rate)

            # Warm up mel-spectrogram computation with model-specific parameters
            from vllm.transformers_utils.processor import cached_processor_from_config

            processor = cached_processor_from_config(self.model_config)
            feature_extractor = None
            if hasattr(processor, "feature_extractor"):
                feature_extractor = processor.feature_extractor
            elif hasattr(processor, "audio_processor"):
                # For models like GraniteSpeech that use audio_processor
                audio_proc = processor.audio_processor
                if hasattr(audio_proc, "feature_extractor"):
                    feature_extractor = audio_proc.feature_extractor
                # If audio_processor doesn't have feature_extractor,
                # skip mel-spectrogram warmup for these models

            if feature_extractor is not None:
                _ = librosa.feature.melspectrogram(
                    y=dummy_audio,
                    sr=self.asr_config.sample_rate,
                    n_mels=getattr(feature_extractor, "n_mels", 128),
                    n_fft=getattr(feature_extractor, "n_fft", 400),
                    hop_length=getattr(feature_extractor, "hop_length", 160),
                )

            warmup_elapsed = time.perf_counter() - warmup_start
            logger.info("Audio preprocessing warmup completed in %.2fs", warmup_elapsed)
        except Exception:
            # Don't fail initialization if warmup fails - log exception and continue
            logger.exception(
                "Audio preprocessing warmup failed (non-fatal). "
                "First request may experience higher latency."
            )

    def _warmup_input_processor(self) -> None:
        """Warm up input processor with dummy audio to avoid first-request latency.

        The first call to renderer.render_cmpl() with multimodal audio
        triggers multimodal processing initialization which can take ~2.5s.
        This method processes a dummy audio request to warm up the pipeline.
        """
        # Skip warmup if model doesn't support transcription
        if not supports_transcription(self.model_cls):
            return

        # Only warm up if model supports transcription methods
        if not hasattr(self.model_cls, "get_generation_prompt"):
            return

        try:
            warmup_start = time.perf_counter()
            logger.info("Warming up multimodal input processor...")

            # Create minimal dummy audio (1 second of silence)
            dummy_audio = np.zeros(int(self.asr_config.sample_rate), dtype=np.float32)

            # Use the same method that _preprocess_speech_to_text uses
            # to create the prompt
            dummy_prompt = self.model_cls.get_generation_prompt(
                audio=dummy_audio,
                stt_config=self.asr_config,
                model_config=self.model_config,
                language="en",
                task_type=self.task_type,
                request_prompt="",
                to_language=None,
            )
            parsed_prompt = parse_model_prompt(self.model_config, dummy_prompt)

            # Process the dummy input through the input processor.
            # This will trigger all the multimodal processing initialization.
            _ = self.renderer.render_cmpl([parsed_prompt])

            warmup_elapsed = time.perf_counter() - warmup_start
            logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
        except Exception:
            # Don't fail initialization if warmup fails - log exception and continue
            logger.exception(
                "Input processor warmup failed (non-fatal). "
                "First request may experience higher latency."
            )

    @cached_property
    def model_cls(self) -> type[SupportsTranscription]:
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsTranscription], model_cls)

    async def _detect_language(
        self,
        audio_chunk: np.ndarray,
        request_id: str,
    ) -> str:
        """Auto-detect the spoken language from an audio chunk.

        Delegates prompt construction and output parsing to the model class
        via ``get_language_detection_prompt`` and
        ``parse_language_detection_output``.
        """
        from vllm.sampling_params import SamplingParams

        prompt = self.model_cls.get_language_detection_prompt(
            audio_chunk,
            self.asr_config,
        )
        allowed_token_ids = self.model_cls.get_language_token_ids(
            self.tokenizer,
        )
        sampling_params = SamplingParams(
            max_tokens=1,
            temperature=0.0,
            allowed_token_ids=allowed_token_ids,
        )

        result_generator = self.engine_client.generate(
            prompt,
            sampling_params,
            request_id,
        )

        final_output: RequestOutput
        async for final_output in result_generator:
            if final_output.finished:
                break

        token_ids = list(final_output.outputs[0].token_ids)
        lang = self.model_cls.parse_language_detection_output(
            token_ids,
            self.tokenizer,
        )

        logger.info("Auto-detected language: '%s'", lang)
        return lang
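The detection step above is a constrained single-token decode: sampling is restricted to the model's language tokens, and exactly one token is drawn greedily. A minimal sketch of the sampling setup (token ids are illustrative, not real vocabulary ids):

from vllm.sampling_params import SamplingParams

language_token_ids = [50259, 50260, 50261]  # illustrative ids for <|en|>, <|zh|>, ...
params = SamplingParams(
    max_tokens=1,             # one token is enough to name the language
    temperature=0.0,          # greedy: take the most likely language token
    allowed_token_ids=language_token_ids,
)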
    async def _preprocess_speech_to_text(
        self,
        request: SpeechToTextRequest,
        audio_data: bytes,
        request_id: str,
    ) -> tuple[list[ProcessorInputs], float]:
        # Validate request
        language = self.model_cls.validate_language(request.language)
        # Skip to_language validation to avoid extra logging for Whisper.
        to_language = (
            self.model_cls.validate_language(request.to_language)
            if request.to_language
            else None
        )

        if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
            raise VLLMValidationError(
                "Maximum file size exceeded",
                parameter="audio_filesize_mb",
                value=len(audio_data) / 1024**2,
            )

        with io.BytesIO(audio_data) as bytes_:
            # NOTE resample to model SR here for efficiency. This is also a
            # pre-requisite for chunking, as it assumes Whisper SR.
            y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)

        duration = librosa.get_duration(y=y, sr=sr)
        do_split_audio = (
            self.asr_config.allow_audio_chunking
            and duration > self.asr_config.max_audio_clip_s
        )

        if not do_split_audio:
            chunks = [y]
        else:
            assert self.asr_config.max_audio_clip_s is not None
            assert self.asr_config.min_energy_split_window_size is not None
            chunks = split_audio(
                audio_data=y,
                sample_rate=int(sr),
                max_clip_duration_s=self.asr_config.max_audio_clip_s,
                overlap_duration_s=self.asr_config.overlap_chunk_second,
                min_energy_window_size=self.asr_config.min_energy_split_window_size,
            )

        if language is None and getattr(
            self.model_cls, "supports_explicit_language_detection", False
        ):
            # Auto-detect language from the first chunk.
            language = await self._detect_language(
                chunks[0], f"{request_id}-lang_detect"
            )
            request.language = language

        parsed_prompts: list[DictPrompt] = []
        for chunk in chunks:
            # The model has control over the construction, as long as it
            # returns a valid PromptType.
            prompt = self.model_cls.get_generation_prompt(
                audio=chunk,
                stt_config=self.asr_config,
                model_config=self.model_config,
                language=language,
                task_type=self.task_type,
                request_prompt=request.prompt,
                to_language=to_language,
            )

            parsed_prompt: DictPrompt
            if request.response_format == "verbose_json":
                parsed_prompt = parse_enc_dec_prompt(prompt)
                parsed_prompt = self._preprocess_verbose_prompt(parsed_prompt)
            else:
                parsed_prompt = parse_model_prompt(self.model_config, prompt)

            parsed_prompts.append(parsed_prompt)

        engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts)

        return engine_prompts, duration

    def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
        dec_prompt = prompt["decoder_prompt"]

        if not (isinstance(dec_prompt, dict) and "prompt" in dec_prompt):
            raise VLLMValidationError(
                "Expected decoder_prompt to contain text",
                parameter="decoder_prompt",
                value=type(dec_prompt).__name__,
            )

        dec_prompt["prompt"] = dec_prompt["prompt"].replace(
            "<|notimestamps|>", "<|0.00|>"
        )

        return prompt

    def _get_verbose_segments(
        self,
        tokens: tuple,
        log_probs: FlatLogprobs | list[dict[int, Logprob]],
        request: SpeechToTextRequest,
        segment_class: type[SpeechToTextSegment],
        start_time: float = 0,
    ) -> list[SpeechToTextSegment]:
        """
        Convert tokens to verbose segments.

        This method expects the model to produce timestamps as tokens
        (similar to Whisper). If the tokens do not include timestamp
        information, the segments may not be generated correctly.

        Note: the `no_speech_prob` field is not supported in this
        implementation and will be None. See docs for details.
        """
        BASE_OFFSET = 0.02
        init_token = self.tokenizer.encode("<|0.00|>", add_special_tokens=False)[0]
        if tokens[-1] == self.tokenizer.eos_token_id:
            tokens = tokens[:-1]

        tokens_with_start = (init_token,) + tokens
        segments: list[SpeechToTextSegment] = []
        last_timestamp_start = 0

        if tokens_with_start[-2] < init_token and tokens_with_start[-1] >= init_token:
            tokens_with_start = tokens_with_start + (tokens_with_start[-1],)
        avg_logprob = 0.0
        for idx in range(1, len(tokens_with_start)):
            # Timestamp tokens (e.g., <|0.00|>) are assumed to be sorted.
            # If the ordering is violated, this slicing may produce incorrect results.
            token = tokens_with_start[idx]
            if token >= init_token and tokens_with_start[idx - 1] >= init_token:
                sliced_timestamp_tokens = tokens_with_start[last_timestamp_start:idx]
                start_timestamp = sliced_timestamp_tokens[0] - init_token
                end_timestamp = sliced_timestamp_tokens[-1] - init_token
                text = self.tokenizer.decode(sliced_timestamp_tokens[1:-1])
                text_bytes = text.encode("utf-8")

                casting_segment = cast(
                    SpeechToTextSegment,
                    segment_class(
                        id=len(segments),
                        seek=start_time,
                        start=start_time + BASE_OFFSET * start_timestamp,
                        end=start_time + BASE_OFFSET * end_timestamp,
                        temperature=request.temperature,
                        text=text,
                        # The compression ratio measures how compressible the
                        # generated text is. A higher ratio indicates more
                        # repetitive content, which is a strong sign of
                        # hallucination in outputs.
                        compression_ratio=len(text_bytes)
                        / len(zlib.compress(text_bytes)),
                        tokens=sliced_timestamp_tokens[1:-1],
                        avg_logprob=avg_logprob / (idx - last_timestamp_start),
                    ),
                )
                segments.append(casting_segment)
                last_timestamp_start = idx
                avg_logprob = 0
            else:
                avg_logprob += log_probs[idx - 1][token].logprob
        return segments
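Two heuristics in `_get_verbose_segments` are worth spelling out. Timestamp tokens are spaced 0.02 s apart (the `BASE_OFFSET` above), so a token's offset from `<|0.00|>` maps directly to seconds; and the compression ratio flags repetitive, likely hallucinated segments. A minimal sketch with illustrative token ids:

import zlib

BASE_OFFSET = 0.02
init_token = 50364                    # illustrative id for <|0.00|>
end_token = 50464                     # illustrative id for <|2.00|>
print(BASE_OFFSET * (end_token - init_token))  # 2.0 seconds

text_bytes = ("hello " * 40).encode("utf-8")   # highly repetitive text
ratio = len(text_bytes) / len(zlib.compress(text_bytes))
print(ratio > 2.4)  # True: past the "compression failed" threshold above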
    async def _create_speech_to_text(
        self,
        audio_data: bytes,
        request: SpeechToTextRequest,
        raw_request: Request,
        response_class: type[ResponseType],
        stream_generator_method: Callable[..., AsyncGenerator[str, None]],
    ) -> T | V | AsyncGenerator[str, None] | ErrorResponse:
        """Base method for speech-to-text operations like transcription and
        translation."""
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # If the engine is dead, raise the engine's DEAD_ERROR.
        # This is required for the streaming case, where we return a
        # success status before we actually start generating text :).
        if self.engine_client.errored:
            raise self.engine_client.dead_error

        if request.response_format not in ["text", "json", "verbose_json"]:
            return self.create_error_response(
                "Currently only response_format `text`, `json`, or "
                "`verbose_json` is supported"
            )

        if (
            request.response_format == "verbose_json"
            and not self.model_cls.supports_segment_timestamp
        ):
            return self.create_error_response(
                f"verbose_json is currently not supported for {request.model}"
            )

        if request.response_format == "verbose_json" and request.stream:
            return self.create_error_response(
                "verbose_json format doesn't support streaming"
            )
        request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"

        request_metadata = RequestResponseMetadata(request_id=request_id)
        if raw_request:
            raw_request.state.request_metadata = request_metadata

        try:
            lora_request = self._maybe_get_adapters(request)

            engine_prompts, duration_s = await self._preprocess_speech_to_text(
                request=request,
                audio_data=audio_data,
                request_id=request_id,
            )

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(e)

        # Schedule the request and get the result generator.
        max_model_len = self.model_config.max_model_len
        list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
        try:
            # Unlike most decoder-only models, whisper generation length is not
            # constrained by the size of the input audio, which is mapped to a
            # fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
            # generated by respecting the extra completion tokens arg.
            max_tokens = get_max_tokens(
                max_model_len,
                request.max_completion_tokens,
                0,
                self.default_sampling_params,
            )

            sampling_params = request.to_sampling_params(
                max_tokens,
                self.default_sampling_params,
            )
            if request.response_format == "verbose_json":
                sampling_params.logprobs = 1

            list_result_generator = []
            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}_{i}"

                self._log_inputs(
                    request_id_item,
                    engine_prompt,
                    params=sampling_params,
                    lora_request=lora_request,
                )

                trace_headers = (
                    None
                    if raw_request is None
                    else await self._get_trace_headers(raw_request.headers)
                )

                generator = self.engine_client.generate(
                    engine_prompt,
                    sampling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                )

                list_result_generator.append(generator)
        except ValueError as e:
            return self.create_error_response(e)

        if request.stream:
            return stream_generator_method(
                request, list_result_generator, request_id, request_metadata, duration_s
            )
        # Non-streaming response.
        total_segments = []
        text_parts = []
        try:
            assert list_result_generator is not None
            segments_types: dict[str, type[SpeechToTextSegment]] = {
                "transcribe": TranscriptionSegment,
                "translate": TranslationSegment,
            }
            segment_class: type[SpeechToTextSegment] = segments_types[self.task_type]
            text = ""
            chunk_size_in_s = self.asr_config.max_audio_clip_s
            if chunk_size_in_s is None:
                assert len(list_result_generator) == 1, (
                    "`max_audio_clip_s` is set to None, audio cannot be chunked"
                )
            for idx, result_generator in enumerate(list_result_generator):
                start_time = (
                    float(idx * chunk_size_in_s) if chunk_size_in_s is not None else 0.0
                )
                async for op in result_generator:
                    if request.response_format == "verbose_json":
                        assert op.outputs[0].logprobs
                        segments: list[SpeechToTextSegment] = (
                            self._get_verbose_segments(
                                tokens=tuple(op.outputs[0].token_ids),
                                segment_class=segment_class,
                                request=request,
                                start_time=start_time,
                                log_probs=op.outputs[0].logprobs,
                            )
                        )

                        total_segments.extend(segments)
                        text_parts.extend([seg.text for seg in segments])
                    else:
                        raw_text = op.outputs[0].text
                        text_parts.append(self.model_cls.post_process_output(raw_text))
            text = "".join(text_parts)
            if self.task_type == "transcribe":
                final_response: ResponseType
                # add usage in TranscriptionResponse.
                usage = {
                    "type": "duration",
                    # rounded up as per openAI specs
                    "seconds": int(math.ceil(duration_s)),
                }
                if request.response_format != "verbose_json":
                    final_response = cast(
                        T, TranscriptionResponse(text=text, usage=usage)
                    )
                else:
                    final_response = cast(
                        V,
                        TranscriptionResponseVerbose(
                            text=text,
                            language=request.language,
                            duration=str(duration_s),
                            segments=total_segments,
                        ),
                    )
            else:
                # no usage in response for translation task
                if request.response_format != "verbose_json":
                    final_response = cast(T, TranslationResponse(text=text))
                else:
                    final_response = cast(
                        V,
                        TranslationResponseVerbose(
                            text=text,
                            language=request.language,
                            duration=str(duration_s),
                            segments=total_segments,
                        ),
                    )
            return final_response
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            return self.create_error_response(e)

    async def _speech_to_text_stream_generator(
        self,
        request: SpeechToTextRequest,
        list_result_generator: list[AsyncGenerator[RequestOutput, None]],
        request_id: str,
        request_metadata: RequestResponseMetadata,
        audio_duration_s: float,
        chunk_object_type: Literal["translation.chunk", "transcription.chunk"],
        response_stream_choice_class: type[TranscriptionResponseStreamChoice]
        | type[TranslationResponseStreamChoice],
        stream_response_class: type[TranscriptionStreamResponse]
        | type[TranslationStreamResponse],
    ) -> AsyncGenerator[str, None]:
        created_time = int(time.time())
        model_name = request.model

        completion_tokens = 0
        num_prompt_tokens = 0

        include_usage = self.enable_force_include_usage or request.stream_include_usage
        include_continuous_usage = (
            request.stream_continuous_usage_stats
            if include_usage and request.stream_continuous_usage_stats
            else False
        )

        try:
            for result_generator in list_result_generator:
                async for res in result_generator:
                    # On first result.
                    if res.prompt_token_ids is not None:
                        num_prompt_tokens = len(res.prompt_token_ids)
                        if audio_tokens := self.model_cls.get_num_audio_tokens(
                            audio_duration_s, self.asr_config, self.model_config
                        ):
                            num_prompt_tokens += audio_tokens

                    # We need to do it here, because if there are exceptions in
                    # the result_generator, it needs to be sent as the FIRST
                    # response (by the try...catch).

                    # Just one output (n=1) supported.
                    assert len(res.outputs) == 1
                    output = res.outputs[0]

                    # TODO: For models that output structured formats (e.g.,
                    # Qwen3-ASR with "language X<asr_text>" prefix), streaming
                    # would need buffering to strip the prefix properly since
                    # deltas may split the tag across chunks.
                    delta_message = DeltaMessage(content=output.text)
                    completion_tokens += len(output.token_ids)

                    if output.finish_reason is None:
                        # Still generating, send delta update.
                        choice_data = response_stream_choice_class(delta=delta_message)
                    else:
                        # Model is finished generating.
                        choice_data = response_stream_choice_class(
                            delta=delta_message,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason,
                        )

                    chunk = stream_response_class(
                        id=request_id,
                        object=chunk_object_type,
                        created=created_time,
                        choices=[choice_data],
                        model=model_name,
                    )

                    # handle usage stats if requested & if continuous
                    if include_continuous_usage:
                        chunk.usage = UsageInfo(
                            prompt_tokens=num_prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=num_prompt_tokens + completion_tokens,
                        )

                    data = chunk.model_dump_json(exclude_unset=True)
                    yield f"data: {data}\n\n"

            # Once the final token is handled, if stream_options.include_usage
            # is sent, send the usage.
            if include_usage:
                final_usage = UsageInfo(
                    prompt_tokens=num_prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=num_prompt_tokens + completion_tokens,
                )

                final_usage_chunk = stream_response_class(
                    id=request_id,
                    object=chunk_object_type,
                    created=created_time,
                    choices=[],
                    model=model_name,
                    usage=final_usage,
                )
                final_usage_data = final_usage_chunk.model_dump_json(
                    exclude_unset=True, exclude_none=True
                )
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
            request_metadata.final_usage_info = UsageInfo(
                prompt_tokens=num_prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=num_prompt_tokens + completion_tokens,
            )

        except Exception as e:
            logger.exception("Error in %s stream generator.", self.task_type)
            data = self.create_streaming_error_response(e)
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"
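Finally, a minimal sketch of consuming the SSE stream emitted by the generator above, using the `requests` library. The server address, model name, and file path are illustrative assumptions:

import json

import requests

with open("audio.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": f},
        data={"model": "openai/whisper-large-v3", "stream": "true"},
        stream=True,
    )

for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data: "):
        continue
    payload = line[len("data: "):]
    if payload == "[DONE]":
        break
    chunk = json.loads(payload)
    if chunk.get("choices"):  # the final usage chunk has an empty choices list
        print(chunk["choices"][0]["delta"].get("content", ""), end="")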