#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# mypy: ignore-errors
"""
vLLM gRPC Server

Starts a gRPC server for vLLM using the VllmEngine protocol.

Usage:
    python -m vllm.entrypoints.grpc_server --model <model>

Example:
    python -m vllm.entrypoints.grpc_server \
        --model meta-llama/Llama-2-7b-hf \
        --host 0.0.0.0 \
        --port 50051
"""

import argparse
import asyncio
import signal
import sys
import time
from collections.abc import AsyncGenerator

import grpc
import uvloop
from grpc_reflection.v1alpha import reflection

from vllm import SamplingParams, TextPrompt, TokensPrompt
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.utils import log_version_and_model
from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, StructuredOutputsParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger(__name__)


class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
    """
    gRPC servicer implementing the VllmEngine service.

    Handles 6 RPCs:
    - Generate: Streaming text generation
    - Embed: Embeddings (TODO)
    - HealthCheck: Health probe
    - Abort: Cancel requests out-of-band
    - GetModelInfo: Model metadata
    - GetServerInfo: Server state
    """

    def __init__(self, async_llm: AsyncLLM, start_time: float):
        """
        Initialize the servicer.

        Args:
            async_llm: The AsyncLLM instance
            start_time: The server start time, in seconds since epoch
        """
        self.async_llm = async_llm
        self.start_time = start_time
        logger.info("VllmEngineServicer initialized")

    async def Generate(
        self,
        request: vllm_engine_pb2.GenerateRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncGenerator[vllm_engine_pb2.GenerateResponse, None]:
        """
        Handle streaming generation requests.

        Args:
            request: The GenerateRequest protobuf
            context: gRPC context

        Yields:
            GenerateResponse protobuf messages (streaming)
        """
        request_id = request.request_id
        logger.debug("Generate request %s received.", request_id)

        try:
            # Extract tokenized input
            if request.WhichOneof("input") == "tokenized":
                prompt: TokensPrompt = {
                    "prompt_token_ids": list(request.tokenized.input_ids)
                }
                if request.tokenized.original_text:
                    prompt["prompt"] = request.tokenized.original_text
            else:
                prompt: TextPrompt = {"prompt": request.text}

            # Build sampling params with detokenize=False
            sampling_params = self._sampling_params_from_proto(
                request.sampling_params, stream=request.stream
            )
            tokenization_kwargs = self._tokenization_kwargs_from_proto(
                request.sampling_params
            )

            async for output in self.async_llm.generate(
                prompt=prompt,
                sampling_params=sampling_params,
                request_id=request_id,
                tokenization_kwargs=tokenization_kwargs,
            ):
                # Convert vLLM output to protobuf
                # For streaming, always send chunks
                if request.stream:
                    yield self._chunk_response(output)

                # Send complete response when finished
                if output.finished:
                    yield self._complete_response(output)
        except ValueError as e:
            # Invalid request (equivalent to HTTP 400).
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
        except Exception as e:
            logger.exception("Error in Generate for request %s", request_id)
            await context.abort(grpc.StatusCode.INTERNAL, str(e))
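    # Wire contract of Generate (illustrative, from the handler above): a
    # streaming call yields zero or more `chunk` messages followed by exactly
    # one `complete` message; a non-streaming call yields only `complete`.
    # Hedged sketch of client-side handling, assuming the generated stub:
    #
    #   async for resp in stub.Generate(req):
    #       if resp.HasField("chunk"):
    #           tokens.extend(resp.chunk.token_ids)  # delta tokens
    #       else:
    #           final = resp.complete  # finish_reason + final counts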
    async def Embed(
        self,
        request: vllm_engine_pb2.EmbedRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.EmbedResponse:
        """
        Handle embedding requests.

        TODO: Implement in Phase 4

        Args:
            request: The EmbedRequest protobuf
            context: gRPC context

        Returns:
            EmbedResponse protobuf
        """
        logger.warning("Embed RPC not yet implemented")
        await context.abort(
            grpc.StatusCode.UNIMPLEMENTED, "Embed RPC not yet implemented"
        )

    async def HealthCheck(
        self,
        request: vllm_engine_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.HealthCheckResponse:
        """
        Handle health check requests.

        Args:
            request: The HealthCheckRequest protobuf
            context: gRPC context

        Returns:
            HealthCheckResponse protobuf
        """
        is_healthy = not self.async_llm.errored
        message = "Healthy" if is_healthy else "Engine is not alive"
        logger.debug(
            "HealthCheck request: healthy=%s, message=%s", is_healthy, message
        )
        return vllm_engine_pb2.HealthCheckResponse(
            healthy=is_healthy, message=message
        )

    async def Abort(
        self,
        request: vllm_engine_pb2.AbortRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.AbortResponse:
        """
        Handle out-of-band abort requests.

        Args:
            request: The AbortRequest protobuf
            context: gRPC context

        Returns:
            AbortResponse protobuf
        """
        request_ids = request.request_ids
        logger.debug("Abort requests: %s", request_ids)
        await self.async_llm.abort(request_ids)
        return vllm_engine_pb2.AbortResponse()

    async def GetModelInfo(
        self,
        request: vllm_engine_pb2.GetModelInfoRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.GetModelInfoResponse:
        """
        Handle model info requests.

        Args:
            request: The GetModelInfoRequest protobuf
            context: gRPC context

        Returns:
            GetModelInfoResponse protobuf
        """
        model_config = self.async_llm.model_config
        return vllm_engine_pb2.GetModelInfoResponse(
            model_path=model_config.model,
            is_generation=model_config.runner_type == "generate",
            max_context_length=model_config.max_model_len,
            vocab_size=model_config.get_vocab_size(),
            supports_vision=model_config.is_multimodal_model,
        )

    async def GetServerInfo(
        self,
        request: vllm_engine_pb2.GetServerInfoRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.GetServerInfoResponse:
        """
        Handle server info requests.

        Args:
            request: The GetServerInfoRequest protobuf
            context: gRPC context

        Returns:
            GetServerInfoResponse protobuf
        """
        num_requests = self.async_llm.output_processor.get_num_unfinished_requests()
        return vllm_engine_pb2.GetServerInfoResponse(
            active_requests=num_requests,
            is_paused=False,  # TODO: wire up real pause state
            last_receive_timestamp=time.time(),  # TODO: this is the send time, not the last-receive time
            uptime_seconds=time.time() - self.start_time,
            server_type="vllm-grpc",
        )
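    # Illustrative only: the two info RPCs above read no fields from their
    # request messages, so a client can poll them cheaply (hedged sketch,
    # response field names as set above):
    #
    #   info = await stub.GetModelInfo(vllm_engine_pb2.GetModelInfoRequest())
    #   print(info.model_path, info.max_context_length)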
    # ========== Helper methods ==========

    @staticmethod
    def _sampling_params_from_proto(
        params: vllm_engine_pb2.SamplingParams, stream: bool = True
    ) -> SamplingParams:
        """
        Convert protobuf SamplingParams to vLLM SamplingParams.

        Args:
            params: Protobuf SamplingParams message
            stream: Whether streaming is enabled

        Returns:
            vLLM SamplingParams with detokenize=False and structured_outputs
        """
        # Build stop sequences
        stop = list(params.stop) if params.stop else None
        stop_token_ids = list(params.stop_token_ids) if params.stop_token_ids else None

        # Handle structured outputs constraints
        structured_outputs = None
        constraint_field = params.WhichOneof("constraint")
        if constraint_field:
            if constraint_field == "json_schema":
                structured_outputs = StructuredOutputsParams(json=params.json_schema)
            elif constraint_field == "regex":
                structured_outputs = StructuredOutputsParams(regex=params.regex)
            elif constraint_field == "grammar":
                structured_outputs = StructuredOutputsParams(grammar=params.grammar)
            elif constraint_field == "structural_tag":
                structured_outputs = StructuredOutputsParams(
                    structural_tag=params.structural_tag
                )
            elif constraint_field == "json_object":
                structured_outputs = StructuredOutputsParams(
                    json_object=params.json_object
                )
            elif constraint_field == "choice":
                structured_outputs = StructuredOutputsParams(
                    choice=list(params.choice.choices)
                )

        # Create SamplingParams
        # output_kind=DELTA: Return only new tokens in each chunk (for streaming)
        return SamplingParams(
            temperature=params.temperature if params.HasField("temperature") else 1.0,
            top_p=params.top_p if params.top_p != 0.0 else 1.0,
            top_k=params.top_k,
            min_p=params.min_p,
            frequency_penalty=params.frequency_penalty,
            presence_penalty=params.presence_penalty,
            repetition_penalty=params.repetition_penalty
            if params.repetition_penalty != 0.0
            else 1.0,
            max_tokens=params.max_tokens if params.HasField("max_tokens") else None,
            min_tokens=params.min_tokens,
            stop=stop,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=params.skip_special_tokens,
            spaces_between_special_tokens=params.spaces_between_special_tokens,
            ignore_eos=params.ignore_eos,
            n=params.n if params.n > 0 else 1,
            logprobs=params.logprobs if params.HasField("logprobs") else None,
            prompt_logprobs=params.prompt_logprobs
            if params.HasField("prompt_logprobs")
            else None,
            seed=params.seed if params.HasField("seed") else None,
            include_stop_str_in_output=params.include_stop_str_in_output,
            logit_bias=dict(params.logit_bias) if params.logit_bias else None,
            structured_outputs=structured_outputs,
            # detokenize must be True if stop strings are used
            detokenize=bool(stop),
            output_kind=RequestOutputKind.DELTA
            if stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @staticmethod
    def _tokenization_kwargs_from_proto(
        params: vllm_engine_pb2.SamplingParams,
    ) -> dict[str, int] | None:
        if params.HasField("truncate_prompt_tokens"):
            return {"truncate_prompt_tokens": params.truncate_prompt_tokens}
        return None
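    # Example of the proto-to-vLLM default mapping above (hedged): an empty
    # SamplingParams message (all proto zero-values) becomes roughly
    #
    #   SamplingParams(temperature=1.0, top_p=1.0, repetition_penalty=1.0,
    #                  n=1, max_tokens=None, detokenize=False, ...)
    #
    # i.e. proto zero-values map to vLLM's "disabled" defaults rather than
    # being passed through literally.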
    @staticmethod
    def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
        """
        Build a streaming chunk response from vLLM output.

        When output_kind=DELTA, vLLM returns only new tokens automatically.

        Args:
            output: vLLM RequestOutput (with delta tokens when output_kind=DELTA)

        Returns:
            GenerateResponse with chunk field set
        """
        # Get the completion output (first one if n > 1)
        completion = output.outputs[0] if output.outputs else None

        if completion is None:
            # Empty chunk
            return vllm_engine_pb2.GenerateResponse(
                chunk=vllm_engine_pb2.GenerateStreamChunk(
                    token_ids=[],
                    prompt_tokens=0,
                    completion_tokens=0,
                    cached_tokens=0,
                ),
            )

        # When output_kind=DELTA, completion.token_ids contains only new tokens
        # vLLM handles the delta logic internally
        # completion_tokens = delta count (client will accumulate)
        return vllm_engine_pb2.GenerateResponse(
            chunk=vllm_engine_pb2.GenerateStreamChunk(
                token_ids=completion.token_ids,
                prompt_tokens=len(output.prompt_token_ids)
                if output.prompt_token_ids
                else 0,
                completion_tokens=len(completion.token_ids),  # Delta count
                cached_tokens=output.num_cached_tokens,
            ),
        )

    @staticmethod
    def _complete_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
        """
        Build a final completion response from vLLM output.

        Args:
            output: vLLM RequestOutput (finished=True)

        Returns:
            GenerateResponse with complete field set
        """
        # Get the completion output (first one if n > 1)
        completion = output.outputs[0] if output.outputs else None

        if completion is None:
            # Empty completion
            return vllm_engine_pb2.GenerateResponse(
                complete=vllm_engine_pb2.GenerateComplete(
                    output_ids=[],
                    finish_reason="error",
                    prompt_tokens=0,
                    completion_tokens=0,
                    cached_tokens=0,
                ),
            )

        # Build complete response
        # When streaming (DELTA mode): completion.token_ids holds only the
        # final delta (often empty); the client accumulates token counts
        # When non-streaming (FINAL_ONLY mode): completion.token_ids has all tokens
        return vllm_engine_pb2.GenerateResponse(
            complete=vllm_engine_pb2.GenerateComplete(
                output_ids=completion.token_ids,
                finish_reason=completion.finish_reason or "stop",
                prompt_tokens=len(output.prompt_token_ids)
                if output.prompt_token_ids
                else 0,
                completion_tokens=len(completion.token_ids),
                cached_tokens=output.num_cached_tokens,
            ),
        )
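# Hedged sketch of a minimal client for this service, shown for illustration
# only (nothing in this module calls it). It assumes the stock generated stub
# class `VllmEngineStub` and only the request fields the server reads above
# (request_id, text, stream, sampling_params); names on your generated code
# may differ.
async def _example_generate_client(address: str = "localhost:50051") -> list[int]:
    """Stream one generation and accumulate the delta token ids client-side."""
    async with grpc.aio.insecure_channel(address) as channel:
        stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
        request = vllm_engine_pb2.GenerateRequest(
            request_id="example-1",
            text="Hello, world!",
            stream=True,
            sampling_params=vllm_engine_pb2.SamplingParams(max_tokens=16),
        )
        token_ids: list[int] = []
        async for response in stub.Generate(request):
            # In DELTA mode each chunk carries only the newly generated tokens.
            if response.HasField("chunk"):
                token_ids.extend(response.chunk.token_ids)
        return token_ids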
async def serve_grpc(args: argparse.Namespace):
    """
    Main serving function.

    Args:
        args: Parsed command line arguments
    """
    log_version_and_model(logger, VLLM_VERSION, args.model)
    logger.info("vLLM gRPC server args: %s", args)

    start_time = time.time()

    # Create engine args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Build vLLM config
    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER
    )

    # Create AsyncLLM
    async_llm = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=UsageContext.OPENAI_API_SERVER,
        enable_log_requests=args.enable_log_requests,
        disable_log_stats=args.disable_log_stats_server,
    )

    # Create servicer
    servicer = VllmEngineServicer(async_llm, start_time)

    # Create gRPC server
    server = grpc.aio.server(
        options=[
            ("grpc.max_send_message_length", -1),
            ("grpc.max_receive_message_length", -1),
        ],
    )

    # Add servicer to server
    vllm_engine_pb2_grpc.add_VllmEngineServicer_to_server(servicer, server)

    # Enable reflection for grpcurl and other tools
    service_names = (
        vllm_engine_pb2.DESCRIPTOR.services_by_name["VllmEngine"].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(service_names, server)

    # Bind to address
    address = f"{args.host}:{args.port}"
    server.add_insecure_port(address)

    # Start server
    await server.start()
    logger.info("vLLM gRPC server started on %s", address)
    logger.info("Server is ready to accept requests")

    # Handle shutdown signals
    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()

    def signal_handler():
        logger.info("Received shutdown signal")
        stop_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    # Serve until shutdown signal
    try:
        await stop_event.wait()
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    finally:
        logger.info("Shutting down vLLM gRPC server...")

        # Stop gRPC server
        await server.stop(grace=5.0)
        logger.info("gRPC server stopped")

        # Shutdown AsyncLLM
        async_llm.shutdown()
        logger.info("AsyncLLM engine stopped")

        logger.info("Shutdown complete")


def main():
    """Main entry point."""
    parser = FlexibleArgumentParser(
        description="vLLM gRPC Server",
    )

    # Server args
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind gRPC server to",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=50051,
        help="Port to bind gRPC server to",
    )
    parser.add_argument(
        "--disable-log-stats-server",
        action="store_true",
        help="Disable stats logging on server side",
    )

    # Add vLLM engine args
    parser = AsyncEngineArgs.add_cli_args(parser)

    args = parser.parse_args()

    # Run server
    try:
        uvloop.run(serve_grpc(args))
    except Exception as e:
        logger.exception("Server failed: %s", e)
        sys.exit(1)


if __name__ == "__main__":
    main()
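# Smoke test via the server reflection enabled above (hedged: the fully
# qualified service name depends on the proto package, so list it first):
#
#   grpcurl -plaintext localhost:50051 list
#   grpcurl -plaintext -d '{}' localhost:50051 <package>.VllmEngine/HealthCheck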