Add minimal vLLM 0.16.1 build repo for BI-V150
vllm/entrypoints/openai/realtime/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/entrypoints/openai/realtime/api_router.py (new file, 75 lines)
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING

from fastapi import APIRouter, FastAPI, WebSocket

from vllm.entrypoints.openai.realtime.connection import RealtimeConnection
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.logger import init_logger

logger = init_logger(__name__)

if TYPE_CHECKING:
    from argparse import Namespace

    from starlette.datastructures import State

    from vllm.engine.protocol import EngineClient
    from vllm.entrypoints.logger import RequestLogger
    from vllm.tasks import SupportedTask
else:
    RequestLogger = object

router = APIRouter()


@router.websocket("/v1/realtime")
async def realtime_endpoint(websocket: WebSocket):
    """WebSocket endpoint for realtime audio transcription.

    Protocol:
    1. Client connects to ws://host/v1/realtime
    2. Server sends session.created event
    3. Client optionally sends session.update with model/params
    4. Client sends input_audio_buffer.commit when ready
    5. Client sends input_audio_buffer.append events with base64 PCM16 chunks
    6. Server processes and sends transcription.delta events
    7. Server sends transcription.done with final text + usage
    8. Repeat from step 5 for the next utterance
    9. Optionally, client sends input_audio_buffer.commit with final=True
       to signal audio input is finished. Useful when streaming audio files.

    Audio format: PCM16, 16kHz, mono, base64-encoded
    """
    app = websocket.app
    serving = app.state.openai_serving_realtime

    connection = RealtimeConnection(websocket, serving)
    await connection.handle_connection()


def attach_router(app: FastAPI):
    """Attach the realtime router to the FastAPI app."""
    app.include_router(router)
    logger.info("Realtime API router attached")


def init_realtime_state(
    engine_client: "EngineClient",
    state: "State",
    args: "Namespace",
    request_logger: RequestLogger | None,
    supported_tasks: tuple["SupportedTask", ...],
):
    state.openai_serving_realtime = (
        OpenAIServingRealtime(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
        )
        if "realtime" in supported_tasks
        else None
    )
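A minimal client sketch for the protocol documented on the endpoint above. The path and event shapes come from this commit; the host/port, the model name, and the third-party `websockets` client library are illustrative assumptions:

import asyncio
import base64
import json

import websockets  # third-party client library, assumed for illustration


async def transcribe(pcm16_chunks: list[bytes]) -> None:
    async with websockets.connect("ws://localhost:8000/v1/realtime") as ws:
        print(await ws.recv())  # session.created

        # Validate the model first (name is hypothetical)
        await ws.send(json.dumps({"type": "session.update", "model": "my-asr-model"}))

        # Commit before appending: the server starts a generation task that
        # consumes audio chunks from its queue as they arrive
        await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))

        for chunk in pcm16_chunks:  # raw PCM16 @ 16kHz, mono
            await ws.send(json.dumps({
                "type": "input_audio_buffer.append",
                "audio": base64.b64encode(chunk).decode("ascii"),
            }))

        # Signal that the audio input is finished
        await ws.send(json.dumps({"type": "input_audio_buffer.commit", "final": True}))

        while True:
            event = json.loads(await ws.recv())
            if event["type"] == "transcription.delta":
                print(event["delta"], end="", flush=True)
            elif event["type"] == "transcription.done":
                print()
                break

asyncio.run(transcribe([b"\x00\x00" * 16000]))  # one second of silence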
vllm/entrypoints/openai/realtime/connection.py (new file, 279 lines)
@@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import base64
import json
from collections.abc import AsyncGenerator
from http import HTTPStatus
from uuid import uuid4

import numpy as np
from fastapi import WebSocket
from starlette.websockets import WebSocketDisconnect

from vllm import envs
from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
from vllm.entrypoints.openai.realtime.protocol import (
    ErrorEvent,
    InputAudioBufferAppend,
    InputAudioBufferCommit,
    SessionCreated,
    TranscriptionDelta,
    TranscriptionDone,
)
from vllm.entrypoints.openai.realtime.serving import OpenAIServingRealtime
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger

logger = init_logger(__name__)


class RealtimeConnection:
    """Manages WebSocket lifecycle and state for realtime transcription.

    This class handles:
    - WebSocket connection lifecycle (accept, receive, send, close)
    - Event routing (session.update, append, commit)
    - Audio buffering via asyncio.Queue
    - Generation task management
    - Error handling and cleanup
    """

    def __init__(self, websocket: WebSocket, serving: OpenAIServingRealtime):
        self.websocket = websocket
        self.connection_id = f"ws-{uuid4()}"
        self.serving = serving
        self.audio_queue: asyncio.Queue[np.ndarray | None] = asyncio.Queue()
        self.generation_task: asyncio.Task | None = None

        self._is_connected = False
        self._is_model_validated = False

        self._max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB

    async def handle_connection(self):
        """Main connection loop."""
        await self.websocket.accept()
        logger.debug("WebSocket connection accepted: %s", self.connection_id)
        self._is_connected = True

        # Send session created event
        await self.send(SessionCreated())

        try:
            while True:
                message = await self.websocket.receive_text()
                try:
                    event = json.loads(message)
                    await self.handle_event(event)
                except json.JSONDecodeError:
                    await self.send_error("Invalid JSON", "invalid_json")
                except Exception as e:
                    logger.exception("Error handling event: %s", e)
                    await self.send_error(str(e), "processing_error")
        except WebSocketDisconnect:
            logger.debug("WebSocket disconnected: %s", self.connection_id)
            self._is_connected = False
        except Exception as e:
            logger.exception("Unexpected error in connection: %s", e)
        finally:
            await self.cleanup()

    def _check_model(self, model: str | None) -> None | ErrorResponse:
        if self.serving._is_model_supported(model):
            return None

        return self.serving.create_error_response(
            message=f"The model `{model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND,
            param="model",
        )

    async def handle_event(self, event: dict):
        """Route events to handlers.

        Supported event types:
        - session.update: Configure model
        - input_audio_buffer.append: Add audio chunk to queue
        - input_audio_buffer.commit: Start transcription generation
        """
        event_type = event.get("type")
        if event_type == "session.update":
            logger.debug("Session updated: %s", event)
            # Only mark the session validated if the model check passes
            model = event.get("model")
            if self._check_model(model) is not None:
                await self.send_error(
                    f"The model `{model}` does not exist.", "model_not_found"
                )
                return
            self._is_model_validated = True
        elif event_type == "input_audio_buffer.append":
            append_event = InputAudioBufferAppend(**event)
            try:
                audio_bytes = base64.b64decode(append_event.audio)
                # Convert PCM16 bytes to float32 numpy array
                audio_array = (
                    np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32)
                    / 32768.0
                )

                if len(audio_array) / 1024**2 > self._max_audio_filesize_mb:
                    raise VLLMValidationError(
                        "Maximum file size exceeded",
                        parameter="audio_filesize_mb",
                        value=len(audio_array) / 1024**2,
                    )
                if len(audio_array) == 0:
                    raise VLLMValidationError("Can't process empty audio.")

                # Put audio chunk in queue
                self.audio_queue.put_nowait(audio_array)

            except Exception as e:
                logger.error("Failed to decode audio: %s", e)
                await self.send_error("Invalid audio data", "invalid_audio")

        elif event_type == "input_audio_buffer.commit":
            if not self._is_model_validated:
                err_msg = (
                    "Model not validated. Make sure to validate the"
                    " model by sending a session.update event."
                )
                await self.send_error(err_msg, "model_not_validated")
                return

            commit_event = InputAudioBufferCommit(**event)
            # final signals that the audio input is finished
            if commit_event.final:
                self.audio_queue.put_nowait(None)
            else:
                await self.start_generation()
        else:
            await self.send_error(f"Unknown event type: {event_type}", "unknown_event")
    async def audio_stream_generator(self) -> AsyncGenerator[np.ndarray, None]:
        """Generator that yields audio chunks from the queue."""
        while True:
            audio_chunk = await self.audio_queue.get()
            if audio_chunk is None:  # Sentinel value to stop
                break
            yield audio_chunk

    async def start_generation(self):
        """Start the transcription generation task."""
        if self.generation_task is not None and not self.generation_task.done():
            logger.warning("Generation already in progress, ignoring commit")
            return

        # Create audio stream generator
        audio_stream = self.audio_stream_generator()
        input_stream = asyncio.Queue[list[int]]()

        # Transform to StreamingInput generator
        streaming_input_gen = self.serving.transcribe_realtime(
            audio_stream, input_stream
        )

        # Start generation task
        self.generation_task = asyncio.create_task(
            self._run_generation(streaming_input_gen, input_stream)
        )

    async def _run_generation(
        self,
        streaming_input_gen: AsyncGenerator,
        input_stream: asyncio.Queue[list[int]],
    ):
        """Run the generation and stream results back to the client.

        This method:
        1. Creates sampling parameters from session config
        2. Passes the streaming input generator to engine.generate()
        3. Streams transcription.delta events as text is generated
        4. Sends final transcription.done event with usage stats
        5. Feeds generated token IDs back to input_stream for the next iteration
        6. Cleans up the audio queue
        """
        request_id = f"rt-{self.connection_id}-{uuid4()}"
        full_text = ""

        prompt_token_ids_len: int = 0
        completion_tokens_len: int = 0

        try:
            # Create sampling params
            from vllm.sampling_params import RequestOutputKind, SamplingParams

            sampling_params = SamplingParams.from_optional(
                temperature=0.0,
                max_tokens=self.serving.model_cls.realtime_max_tokens,
                output_kind=RequestOutputKind.DELTA,
                skip_clone=True,
            )

            # Pass the streaming input generator to the engine.
            # The engine will consume audio chunks as they arrive and
            # stream back transcription results incrementally.
            result_gen = self.serving.engine_client.generate(
                prompt=streaming_input_gen,
                sampling_params=sampling_params,
                request_id=request_id,
            )

            # Stream results back to the client as they're generated
            async for output in result_gen:
                if output.outputs and len(output.outputs) > 0:
                    if not prompt_token_ids_len and output.prompt_token_ids:
                        prompt_token_ids_len = len(output.prompt_token_ids)

                    delta = output.outputs[0].text
                    full_text += delta

                    # Append output tokens to the input context
                    input_stream.put_nowait(list(output.outputs[0].token_ids))
                    await self.send(TranscriptionDelta(delta=delta))

                    completion_tokens_len += len(output.outputs[0].token_ids)

                if not self._is_connected:
                    # Finish early because the websocket connection was closed
                    break

            usage = UsageInfo(
                prompt_tokens=prompt_token_ids_len,
                completion_tokens=completion_tokens_len,
                total_tokens=prompt_token_ids_len + completion_tokens_len,
            )

            # Send final completion event
            await self.send(TranscriptionDone(text=full_text, usage=usage))

            # Clear queue for the next utterance
            while not self.audio_queue.empty():
                self.audio_queue.get_nowait()

        except Exception as e:
            logger.exception("Error in generation: %s", e)
            await self.send_error(str(e), "processing_error")

    async def send(
        self, event: SessionCreated | TranscriptionDelta | TranscriptionDone
    ):
        """Send event to client."""
        data = event.model_dump_json()
        await self.websocket.send_text(data)

    async def send_error(self, message: str, code: str | None = None):
        """Send error event to client."""
        error_event = ErrorEvent(error=message, code=code)
        await self.websocket.send_text(error_event.model_dump_json())

    async def cleanup(self):
        """Cleanup resources."""
        # Signal audio stream to stop
        self.audio_queue.put_nowait(None)

        # Cancel generation task if running
        if self.generation_task and not self.generation_task.done():
            self.generation_task.cancel()

        logger.debug("Connection cleanup complete: %s", self.connection_id)
vllm/entrypoints/openai/realtime/protocol.py (new file, 68 lines)
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from typing import Literal

from pydantic import Field

from vllm.entrypoints.openai.engine.protocol import (
    OpenAIBaseModel,
    UsageInfo,
)
from vllm.utils import random_uuid

# Client -> Server Events


class InputAudioBufferAppend(OpenAIBaseModel):
    """Append audio chunk to buffer"""

    type: Literal["input_audio_buffer.append"] = "input_audio_buffer.append"
    audio: str  # base64-encoded PCM16 @ 16kHz


class InputAudioBufferCommit(OpenAIBaseModel):
    """Process accumulated audio buffer"""

    type: Literal["input_audio_buffer.commit"] = "input_audio_buffer.commit"
    final: bool = False


class SessionUpdate(OpenAIBaseModel):
    """Configure session parameters"""

    type: Literal["session.update"] = "session.update"
    model: str | None = None


# Server -> Client Events


class SessionCreated(OpenAIBaseModel):
    """Connection established notification"""

    type: Literal["session.created"] = "session.created"
    id: str = Field(default_factory=lambda: f"sess-{random_uuid()}")
    created: int = Field(default_factory=lambda: int(time.time()))


class TranscriptionDelta(OpenAIBaseModel):
    """Incremental transcription text"""

    type: Literal["transcription.delta"] = "transcription.delta"
    delta: str  # Incremental text


class TranscriptionDone(OpenAIBaseModel):
    """Final transcription with usage stats"""

    type: Literal["transcription.done"] = "transcription.done"
    text: str  # Complete transcription
    usage: UsageInfo | None = None


class ErrorEvent(OpenAIBaseModel):
    """Error notification"""

    type: Literal["error"] = "error"
    error: str
    code: str | None = None
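For reference, one event of each type as it would appear on the wire when these models are serialized; field values are illustrative and the model name is a hypothetical placeholder:

{"type": "session.created", "id": "sess-1234abcd", "created": 1700000000}
{"type": "session.update", "model": "my-asr-model"}
{"type": "input_audio_buffer.append", "audio": "<base64 PCM16 @ 16kHz>"}
{"type": "input_audio_buffer.commit", "final": false}
{"type": "transcription.delta", "delta": "hello "}
{"type": "transcription.done", "text": "hello world", "usage": {"prompt_tokens": 120, "completion_tokens": 3, "total_tokens": 123}}
{"type": "error", "error": "Invalid JSON", "code": "invalid_json"}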
vllm/entrypoints/openai/realtime/serving.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
from collections.abc import AsyncGenerator
from functools import cached_property
from typing import Literal, cast

import numpy as np

from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime
from vllm.renderers.inputs.preprocess import parse_model_prompt

logger = init_logger(__name__)


class OpenAIServingRealtime(OpenAIServing):
    """Realtime audio transcription service via WebSocket streaming.

    Provides streaming audio-to-text transcription by transforming audio
    chunks into StreamingInput objects that can be consumed by the engine.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        log_error_stack: bool = False,
    ):
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        self.task_type: Literal["realtime"] = "realtime"

        logger.info("OpenAIServingRealtime initialized for task: %s", self.task_type)

    @cached_property
    def model_cls(self) -> type[SupportsRealtime]:
        """Get the model class that supports realtime transcription."""
        from vllm.model_executor.model_loader import get_model_cls

        model_cls = get_model_cls(self.model_config)
        return cast(type[SupportsRealtime], model_cls)

    async def transcribe_realtime(
        self,
        audio_stream: AsyncGenerator[np.ndarray, None],
        input_stream: asyncio.Queue[list[int]],
    ) -> AsyncGenerator[StreamingInput, None]:
        """Transform audio stream into StreamingInput for engine.generate().

        Args:
            audio_stream: Async generator yielding float32 numpy audio arrays
            input_stream: Queue containing context token IDs from previous
                generation outputs. Used for autoregressive multi-turn
                processing where each generation's output becomes the context
                for the next iteration.

        Yields:
            StreamingInput objects containing audio prompts for the engine
        """
        model_config = self.model_config
        renderer = self.renderer

        # mypy cannot infer the generator type here
        # TODO(Patrick) - fix this
        stream_input_iter = cast(
            AsyncGenerator[PromptType, None],
            self.model_cls.buffer_realtime_audio(
                audio_stream, input_stream, model_config
            ),
        )

        async for prompt in stream_input_iter:
            parsed_prompt = parse_model_prompt(model_config, prompt)
            (engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt])

            yield StreamingInput(prompt=engine_prompt)
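The input_stream queue closes an autoregressive loop: _run_generation in connection.py pushes each output's token_ids onto it, and buffer_realtime_audio folds them into the context for the next prompt. A toy, self-contained sketch of that feedback shape (the dict prompt and fake token IDs are illustrative stand-ins, not the real PromptType or model hook):

import asyncio


async def audio_chunks():
    for chunk in ("hello", "world"):
        yield chunk


async def buffer_with_context(audio, input_stream: asyncio.Queue):
    context: list[int] = []
    async for chunk in audio:
        # Fold previously generated tokens back into the next prompt
        while not input_stream.empty():
            context.extend(input_stream.get_nowait())
        yield {"context": list(context), "audio": chunk}


async def main():
    input_stream: asyncio.Queue = asyncio.Queue()
    async for prompt in buffer_with_context(audio_chunks(), input_stream):
        print(prompt)
        # Pretend the engine generated these token IDs for this prompt
        input_stream.put_nowait([101, 102])

asyncio.run(main())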