Files

159 lines
6.2 KiB
Python
Raw Permalink Normal View History

2026-04-02 04:53:13 +00:00
import argparse
import copy
import dataclasses
import functools
import json
import sys
import threading
import warnings
from dataclasses import MISSING, dataclass, fields, is_dataclass
from itertools import permutations
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
Type, TypeVar, Union, cast, get_args, get_origin)
import regex as re
import torch
from pydantic import TypeAdapter, ValidationError
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip,
is_in_ray_actor)
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import CpuArchEnum, current_platform
# Module-level logger obtained from vLLM's `init_logger`, keyed by this
# module's `__name__`.
logger = init_logger(__name__)
def _set_default_args(self, usage_context: UsageContext,
                      model_config: ModelConfig) -> None:
    """Set default arguments for the V1 engine.

    Fills in engine arguments that were left unset, choosing values from
    per-usage-context tables that depend on the hardware reported by
    ``current_platform``.

    Effects on ``self``:
        * ``enable_chunked_prefill`` and ``enable_prefix_caching`` are
          unconditionally set to ``False``.
        * ``scheduler_cls`` is replaced by the V1 scheduler path when it
          still equals ``EngineArgs.scheduler_cls``.
        * ``max_num_batched_tokens`` and ``max_num_seqs`` are populated
          when they are ``None`` and ``usage_context`` has an entry in the
          corresponding default table.

    Args:
        usage_context: Key into the per-context default tables. A value
            with no table entry (including ``None``) leaves
            ``max_num_batched_tokens`` / ``max_num_seqs`` untouched.
        model_config: Model configuration; ``runner_type`` and
            ``max_model_len`` are read.
    """
    # Snapshot the caller-supplied value before it is overridden below.
    # Checking the attribute after the override would make the
    # prompt-embeds warning unreachable.
    requested_prefix_caching = self.enable_prefix_caching

    # Chunked prefill and prefix caching are both forced off here,
    # regardless of the runner type or of any value the caller supplied.
    self.enable_chunked_prefill = False
    self.enable_prefix_caching = False

    if model_config.runner_type != "pooling":
        # TODO: When prefix caching supports prompt embeds inputs, this
        # check can be removed.
        # `is not False` means an unset (None) request also triggers the
        # warning, exactly as the condition was originally written.
        if (self.enable_prompt_embeds
                and requested_prefix_caching is not False):
            logger.warning(
                "--enable-prompt-embeds and --enable-prefix-caching "
                "are not supported together in V1. Prefix caching has "
                "been disabled.")

    # V1 should use the new scheduler by default.
    # Swap it only if this arg is set to the original V0 default
    if self.scheduler_cls == EngineArgs.scheduler_cls:
        self.scheduler_cls = "vllm.v1.core.sched.scheduler.Scheduler"

    # When no user override, set the default values based on the usage
    # context.
    # Use different default values for different hardware.

    # Try to query the device name on the current platform. If it fails,
    # it may be because the platform that imports vLLM is not the same
    # as the platform that vLLM is running on (e.g. the case of scaling
    # vLLM with Ray) and has no GPUs. In this case we use the default
    # values for non-H100/H200 GPUs.
    try:
        device_memory = current_platform.get_device_total_memory()
        device_name = current_platform.get_device_name().lower()
    except Exception:
        # Deliberate best-effort: any failure selects the conservative
        # defaults. This is only used to set
        # default_max_num_batched_tokens.
        device_memory = 0
        # Bind the name as well, so the A100 check below never depends on
        # `and` short-circuiting to avoid an UnboundLocalError.
        device_name = ""

    # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
    # throughput, see PR #17885 for more details.
    # So here we do an extra device name check to prevent such regression.
    from vllm.usage.usage_lib import UsageContext
    if device_memory >= 70 * GiB_bytes and "a100" not in device_name:
        # For GPUs like H100 and MI300x, use larger default values.
        default_max_num_batched_tokens = {
            UsageContext.LLM_CLASS: 16384,
            UsageContext.OPENAI_API_SERVER: 8192,
        }
        default_max_num_seqs = {
            UsageContext.LLM_CLASS: 1024,
            UsageContext.OPENAI_API_SERVER: 1024,
        }
    else:
        # TODO(woosuk): Tune the default values for other hardware.
        default_max_num_batched_tokens = {
            UsageContext.LLM_CLASS: 8192,
            UsageContext.OPENAI_API_SERVER: 2048,
        }
        default_max_num_seqs = {
            UsageContext.LLM_CLASS: 4,
            UsageContext.OPENAI_API_SERVER: 4,
        }

    # tpu specific default values.
    if current_platform.is_tpu():
        # Keyed first by usage context, then by the chip name returned by
        # `current_platform.get_device_name()`.
        default_max_num_batched_tokens_tpu = {
            UsageContext.LLM_CLASS: {
                'V6E': 2048,
                'V5E': 1024,
                'V5P': 512,
            },
            UsageContext.OPENAI_API_SERVER: {
                'V6E': 1024,
                'V5E': 512,
                'V5P': 256,
            }
        }

    # cpu specific default values.
    if current_platform.is_cpu():
        # On CPU the tables are replaced and scale with the world size.
        world_size = self.pipeline_parallel_size * self.tensor_parallel_size
        default_max_num_batched_tokens = {
            UsageContext.LLM_CLASS: 4096 * world_size,
            UsageContext.OPENAI_API_SERVER: 2048 * world_size,
        }
        default_max_num_seqs = {
            UsageContext.LLM_CLASS: 256 * world_size,
            UsageContext.OPENAI_API_SERVER: 128 * world_size,
        }

    # Only used in the debug messages below.
    use_context_value = usage_context.value if usage_context else None
    if (self.max_num_batched_tokens is None
            and usage_context in default_max_num_batched_tokens):
        if current_platform.is_tpu():
            chip_name = current_platform.get_device_name()
            if chip_name in default_max_num_batched_tokens_tpu[
                    usage_context]:
                self.max_num_batched_tokens = \
                    default_max_num_batched_tokens_tpu[
                        usage_context][chip_name]
            else:
                # Chip not listed: fall back to the generic table.
                self.max_num_batched_tokens = \
                    default_max_num_batched_tokens[usage_context]
        else:
            # NOTE(review): chunked prefill is forced off at the top of
            # this function, so the first branch is always taken here and
            # the table lookup in the `else` is currently unreachable.
            if not self.enable_chunked_prefill:
                self.max_num_batched_tokens = model_config.max_model_len
            else:
                self.max_num_batched_tokens = \
                    default_max_num_batched_tokens[usage_context]
        logger.debug(
            "Setting max_num_batched_tokens to %d for %s usage context.",
            self.max_num_batched_tokens, use_context_value)

    if (self.max_num_seqs is None
            and usage_context in default_max_num_seqs):
        # Never allow more sequences than batched tokens; if the token
        # budget is unset (or 0), only the table value applies.
        self.max_num_seqs = min(default_max_num_seqs[usage_context],
                                self.max_num_batched_tokens or sys.maxsize)

        logger.debug("Setting max_num_seqs to %d for %s usage context.",
                     self.max_num_seqs, use_context_value)