# SPDX-License-Identifier: Apache-2.0

import asyncio
import tempfile
from collections.abc import Awaitable
from http import HTTPStatus
from io import StringIO
from typing import Callable, Optional

import aiohttp
import torch
from prometheus_client import start_http_server
from tqdm import tqdm

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                              BatchRequestOutput,
                                              BatchResponseData,
                                              ChatCompletionResponse,
                                              EmbeddingResponse, ErrorResponse,
                                              ScoreResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION


def parse_args():
    """Build the batch runner's command-line interface and parse argv.

    Options are registered in this order: batch input/output locations, the
    response role, every async-engine option contributed by
    ``AsyncEngineArgs.add_cli_args``, request-log truncation, the Prometheus
    metrics server, and prompt-token usage details. Returns the namespace
    produced by the parser's ``parse_args()``.
    """
    cli = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible batch runner.")

    # Where the batch comes from and where its results go.
    cli.add_argument(
        "-i",
        "--input-file",
        type=str,
        required=True,
        help=("The path or url to a single input file. Currently supports "
              "local file paths, or the http protocol (http or https). If a "
              "URL is specified, the file should be available via HTTP GET."),
    )
    cli.add_argument(
        "-o",
        "--output-file",
        type=str,
        required=True,
        help=("The path or url to a single output file. Currently supports "
              "local file paths, or web (http or https) urls. If a URL is "
              "specified, the file should be available via HTTP PUT."),
    )
    cli.add_argument(
        "--output-tmp-dir",
        default=None,
        type=str,
        help=("The directory to store the output file before uploading it "
              "to the output URL."),
    )
    cli.add_argument(
        "--response-role",
        default="assistant",
        type=nullable_str,
        help=("The role name to return if "
              "`request.add_generation_prompt=True`."),
    )

    # Engine configuration flags come from the engine itself.
    cli = AsyncEngineArgs.add_cli_args(cli)

    cli.add_argument(
        "--max-log-len",
        default=None,
        type=int,
        help=("Max number of prompt characters or prompt ID numbers being "
              "printed in log.\n\nDefault: Unlimited"),
    )

    # Prometheus metrics server; --url/--port only matter with the flag.
    metrics_note = "(only needed if enable-metrics is set)."
    cli.add_argument(
        "--enable-metrics",
        action="store_true",
        help="Enable Prometheus metrics",
    )
    cli.add_argument(
        "--url",
        default="0.0.0.0",
        type=str,
        help="URL to the Prometheus metrics server " + metrics_note,
    )
    cli.add_argument(
        "--port",
        default=8000,
        type=int,
        help="Port number for the Prometheus metrics server " + metrics_note,
    )

    cli.add_argument(
        "--enable-prompt-tokens-details",
        default=False,
        action="store_true",
        help="If set to True, enable prompt_tokens_details in usage.",
    )

    return cli.parse_args()


# Explicitly use a pure-text bar format that ends in a newline. This makes
# the progress-bar animation invisible, but avoids interfering with ray or
# multiprocessing, which wrap each line of output with some prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


class BatchProgressTracker:
    """Count submitted batch requests and report completions on a tqdm bar.

    Call ``submitted()`` once per request handed to the engine, then create
    the bar with ``pbar()`` (its total is the number of submissions so far),
    and call ``completed()`` as each submitted request finishes.
    """

    def __init__(self):
        # Number of requests registered via submitted().
        self._total = 0
        # Created by pbar(); completed() does nothing until then.
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        """Register one more request that will later call completed()."""
        self._total += 1

    def completed(self):
        """Advance the progress bar by one step, if it has been created."""
        # Explicit None check instead of relying on tqdm's __bool__, which
        # depends on the bar's total.
        if self._pbar is not None:
            self._pbar.update()

    def pbar(self) -> tqdm:
        """Create (and remember) a progress bar sized to the submissions.

        When torch.distributed is initialized, the bar is disabled on every
        rank except 0 so that only one process prints progress.
        """
        enable_tqdm = not torch.distributed.is_initialized(
        ) or torch.distributed.get_rank() == 0
        self._pbar = tqdm(total=self._total,
                          unit="req",
                          desc="Running batch",
                          mininterval=5,
                          disable=not enable_tqdm,
                          bar_format=_BAR_FORMAT)
        return self._pbar


def _is_http_url(path_or_url: str) -> bool:
    """Return True if path_or_url starts with ``http://`` or ``https://``."""
    return path_or_url.startswith(("http://", "https://"))


async def read_file(path_or_url: str) -> str:
    """Return the full text of a local file or of an http(s) URL.

    Local files are read as UTF-8. For URLs, an HTTP error status (4xx/5xx)
    raises ``aiohttp.ClientResponseError`` rather than returning the error
    page, which would otherwise be parsed as batch input.
    """
    if _is_http_url(path_or_url):
        async with aiohttp.ClientSession() as session, \
                session.get(path_or_url) as resp:
            resp.raise_for_status()
            return await resp.text()
    with open(path_or_url, encoding="utf-8") as f:
        return f.read()


async def write_local_file(output_path: str,
                           batch_outputs: list[BatchRequestOutput]) -> None:
    """
    Write the responses to a local file, one JSON object per line.

    output_path: The path to write the responses to (overwritten).
    batch_outputs: The list of batch outputs to write.
    """
    # We should make this async, but as long as run_batch runs as a
    # standalone program, blocking the event loop won't affect performance.
    with open(output_path, "w", encoding="utf-8") as f:
        for o in batch_outputs:
            print(o.model_dump_json(), file=f)


async def _raise_for_upload_status(response: aiohttp.ClientResponse,
                                   what: str) -> None:
    """Raise Exception if an upload PUT did not return a 2xx status.

    Any 2xx status counts as success: a PUT that creates a new resource
    normally answers 201 Created, and one that replaces it may answer
    204 No Content, so requiring exactly 200 rejected successful uploads.

    response: The response of the PUT request.
    what: Noun used in the error message ("file" or "data").
    """
    if not 200 <= response.status < 300:
        # ClientResponse.text() is a coroutine; it must be awaited or the
        # message would contain a coroutine repr instead of the body.
        raise Exception(f"Failed to upload {what}.\n"
                        f"Status: {response.status}\n"
                        f"Response: {await response.text()}")


async def upload_data(output_url: str, data_or_file: str,
                      from_file: bool) -> None:
    """
    Upload data or a local file to a URL with HTTP PUT, retrying on failure.

    output_url: The URL to upload the file to.
    data_or_file: Either the data to upload or the path to the file to upload.
        (write_file passes the in-memory payload as UTF-8 encoded bytes.)
    from_file: If True, data_or_file is the path to the file to upload.

    Returns after the first successful (2xx) upload. Raises Exception once
    every attempt has failed.
    """
    # Timeout is a common issue when uploading large files.
    # We retry max_retries times before giving up.
    max_retries = 5
    # Number of seconds to wait before retrying.
    delay = 5

    for attempt in range(1, max_retries + 1):
        try:
            # We increase the timeout to 1000 seconds to allow
            # for large files (default is 300).
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(
                    total=1000)) as session:
                if from_file:
                    # Reopened on every attempt so a retry sends the whole
                    # file again from the beginning.
                    with open(data_or_file, "rb") as file:
                        async with session.put(output_url,
                                               data=file) as response:
                            await _raise_for_upload_status(response, "file")
                else:
                    async with session.put(output_url,
                                           data=data_or_file) as response:
                        await _raise_for_upload_status(response, "data")
        except Exception as e:
            if attempt < max_retries:
                logger.error(
                    "Failed to upload data (attempt %d). "
                    "Error message: %s.\nRetrying in %d seconds...", attempt,
                    e, delay)
                await asyncio.sleep(delay)
            else:
                raise Exception(f"Failed to upload data (attempt {attempt}). "
                                f"Error message: {str(e)}.") from e
        else:
            # Success: stop here. Without this the loop would upload the
            # same payload again on every remaining iteration.
            return


async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput],
                     output_tmp_dir: Optional[str]) -> None:
    """
    Write batch_outputs to a file or upload to a URL.

    path_or_url: The path or URL to write batch_outputs to.
    batch_outputs: The list of batch outputs to write.
    output_tmp_dir: The directory to store the output file before uploading
        it to the output URL. If None, URL outputs are built in memory.
    """
    if not _is_http_url(path_or_url):
        logger.info("Writing outputs to local file %s", path_or_url)
        await write_local_file(path_or_url, batch_outputs)
        return

    if output_tmp_dir is None:
        logger.info("Writing outputs to memory buffer")
        output_buffer = StringIO()
        for o in batch_outputs:
            print(o.model_dump_json(), file=output_buffer)
        output_buffer.seek(0)
        logger.info("Uploading outputs to %s", path_or_url)
        await upload_data(
            path_or_url,
            output_buffer.read().strip().encode("utf-8"),
            from_file=False,
        )
        return

    # Write responses to a temporary file and then upload it to the URL.
    # The file is removed when the `with` block exits.
    with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=output_tmp_dir,
            prefix="tmp_batch_output_",
            suffix=".jsonl",
    ) as f:
        logger.info("Writing outputs to temporary local file %s", f.name)
        await write_local_file(f.name, batch_outputs)
        logger.info("Uploading outputs to %s", path_or_url)
        await upload_data(path_or_url, f.name, from_file=True)


def make_error_request_output(request: BatchRequestInput,
                              error_msg: str) -> BatchRequestOutput:
    """Build a BAD_REQUEST BatchRequestOutput carrying error_msg for request.

    The output keeps the request's custom_id so callers can match it back
    to the input line.
    """
    batch_output = BatchRequestOutput(
        id=f"vllm-{random_uuid()}",
        custom_id=request.custom_id,
        response=BatchResponseData(
            status_code=HTTPStatus.BAD_REQUEST,
            request_id=f"vllm-batch-{random_uuid()}",
        ),
        error=error_msg,
    )
    return batch_output


async def make_async_error_request_output(
        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
    """Coroutine form of make_error_request_output.

    Lets rejected requests be gathered together with real engine requests.
    """
    return make_error_request_output(request, error_msg)


async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    """Run one request through a serving handler and wrap the result.

    Successful chat/embedding/score responses become the output body. An
    ErrorResponse becomes an output with its status code. Any other result
    (e.g. a streaming response) becomes a BAD_REQUEST error output. The
    tracker is advanced once the request has finished.
    """
    response = await serving_engine_func(request.body)

    if isinstance(response,
                  (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
        batch_output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                body=response, request_id=f"vllm-batch-{random_uuid()}"),
            error=None,
        )
    elif isinstance(response, ErrorResponse):
        batch_output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                status_code=response.code,
                request_id=f"vllm-batch-{random_uuid()}"),
            error=response,
        )
    else:
        batch_output = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode")

    tracker.completed()
    return batch_output


async def main(args):
    """Run every request of ``args.input_file`` and write the results.

    Builds an AsyncLLMEngine from the CLI args, creates the OpenAI serving
    objects the model supports, dispatches each non-empty JSONL line by its
    ``url``, awaits all of them concurrently, and writes one output per
    request (in input order) to ``args.output_file``.
    """
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)

    model_config = await engine.get_model_config()
    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    # Create the openai serving objects.
    openai_serving_models = OpenAIServingModels(
        engine_client=engine,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=None,
        prompt_adapters=None,
    )
    openai_serving_chat = OpenAIServingChat(
        engine,
        model_config,
        openai_serving_models,
        args.response_role,
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.runner_type == "generate" else None
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
    ) if model_config.task == "embed" else None
    openai_serving_scores = (ServingScores(
        engine,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
    ) if model_config.task == "score" else None)

    # Supported endpoints -> (handler or None when the loaded model does not
    # provide it, error message reported in that case).
    endpoint_handlers: dict[str, tuple[Optional[Callable], str]] = {
        "/v1/chat/completions": (
            None if openai_serving_chat is None else
            openai_serving_chat.create_chat_completion,
            "The model does not support Chat Completions API",
        ),
        "/v1/embeddings": (
            None if openai_serving_embedding is None else
            openai_serving_embedding.create_embedding,
            "The model does not support Embeddings API",
        ),
        "/v1/score": (
            None if openai_serving_scores is None else
            openai_serving_scores.create_score,
            "The model does not support Scores API",
        ),
    }

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

    # Submit all requests in the file to the engine "concurrently".
    response_futures: list[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        # Skip empty lines.
        request_json = request_json.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        # Determine the type of request and run it.
        endpoint = endpoint_handlers.get(request.url)
        if endpoint is None:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg=
                    "Only /v1/chat/completions, /v1/embeddings, and /v1/score "
                    "are supported in the batch endpoint.",
                ))
            continue

        handler_fn, unsupported_msg = endpoint
        if handler_fn is None:
            response_futures.append(
                make_async_error_request_output(request,
                                                error_msg=unsupported_msg))
            continue

        response_futures.append(run_request(handler_fn, request, tracker))
        tracker.submitted()

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    await write_file(args.output_file, responses, args.output_tmp_dir)


if __name__ == "__main__":
    args = parse_args()

    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Start the Prometheus metrics server. LLMEngine uses the Prometheus client
    # to publish metrics at the /metrics endpoint.
    if args.enable_metrics:
        logger.info("Prometheus metrics enabled")
        start_http_server(port=args.port, addr=args.url)
    else:
        logger.info("Prometheus metrics disabled")

    asyncio.run(main(args))